zaydzuhri committed
Commit 1e817be · verified · 1 Parent(s): 8a3fb14

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. fla/layers/__init__.py +44 -0
  2. fla/layers/__pycache__/attn.cpython-312.pyc +0 -0
  3. fla/layers/__pycache__/forgetting_attn.cpython-312.pyc +0 -0
  4. fla/layers/__pycache__/gated_deltanet.cpython-312.pyc +0 -0
  5. fla/layers/__pycache__/gated_deltaproduct.cpython-312.pyc +0 -0
  6. fla/layers/__pycache__/gla.cpython-312.pyc +0 -0
  7. fla/layers/__pycache__/multiscale_retention.cpython-312.pyc +0 -0
  8. fla/layers/__pycache__/rwkv7.cpython-312.pyc +0 -0
  9. fla/layers/based.py +96 -0
  10. fla/layers/bitattn.py +192 -0
  11. fla/layers/hgrn.py +168 -0
  12. fla/layers/linear_attn.py +166 -0
  13. fla/layers/rwkv6.py +307 -0
  14. fla/layers/rwkv7.py +221 -0
  15. fla/layers/simple_gla.py +261 -0
  16. fla/models/__init__.py +53 -0
  17. fla/models/abc/__pycache__/configuration_abc.cpython-312.pyc +0 -0
  18. fla/models/bitnet/__init__.py +13 -0
  19. fla/models/delta_net/__init__.py +12 -0
  20. fla/models/delta_net/__pycache__/modeling_delta_net.cpython-312.pyc +0 -0
  21. fla/models/delta_net/modeling_delta_net.py +415 -0
  22. fla/models/forgetting_transformer/__pycache__/__init__.cpython-312.pyc +0 -0
  23. fla/models/forgetting_transformer/__pycache__/modeling_forgetting_transformer.cpython-312.pyc +0 -0
  24. fla/models/gated_deltanet/__init__.py +12 -0
  25. fla/models/gated_deltanet/__pycache__/__init__.cpython-312.pyc +0 -0
  26. fla/models/gated_deltanet/__pycache__/configuration_gated_deltanet.cpython-312.pyc +0 -0
  27. fla/models/gated_deltanet/__pycache__/modeling_gated_deltanet.cpython-312.pyc +0 -0
  28. fla/models/gla/__pycache__/__init__.cpython-312.pyc +0 -0
  29. fla/models/gsa/__init__.py +13 -0
  30. fla/models/gsa/modeling_gsa.py +420 -0
  31. fla/models/hgrn/__init__.py +13 -0
  32. fla/models/hgrn2/__pycache__/__init__.cpython-312.pyc +0 -0
  33. fla/models/lightnet/__pycache__/__init__.cpython-312.pyc +0 -0
  34. fla/models/lightnet/__pycache__/modeling_lightnet.cpython-312.pyc +0 -0
  35. fla/models/lightnet/modeling_lightnet.py +410 -0
  36. fla/models/linear_attn/__pycache__/__init__.cpython-312.pyc +0 -0
  37. fla/models/linear_attn/__pycache__/configuration_linear_attn.cpython-312.pyc +0 -0
  38. fla/models/mamba/configuration_mamba.py +166 -0
  39. fla/models/nsa/__init__.py +15 -0
  40. fla/models/nsa/modeling_nsa.py +398 -0
  41. fla/models/retnet/__init__.py +13 -0
  42. fla/models/retnet/configuration_retnet.py +92 -0
  43. fla/models/retnet/modeling_retnet.py +425 -0
  44. fla/models/rwkv6/__pycache__/__init__.cpython-312.pyc +0 -0
  45. fla/models/rwkv6/__pycache__/modeling_rwkv6.cpython-312.pyc +0 -0
  46. fla/models/rwkv7/__pycache__/__init__.cpython-312.pyc +0 -0
  47. fla/models/rwkv7/__pycache__/modeling_rwkv7.cpython-312.pyc +0 -0
  48. fla/models/rwkv7/modeling_rwkv7.py +505 -0
  49. fla/models/samba/__init__.py +13 -0
  50. fla/models/samba/__pycache__/__init__.cpython-312.pyc +0 -0
fla/layers/__init__.py ADDED
@@ -0,0 +1,44 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from .abc import ABCAttention
from .attn import Attention
from .based import BasedLinearAttention
from .bitattn import BitAttention
from .delta_net import DeltaNet
from .forgetting_attn import ForgettingAttention
from .gated_deltanet import GatedDeltaNet
from .gated_deltaproduct import GatedDeltaProduct
from .gla import GatedLinearAttention
from .gsa import GatedSlotAttention
from .hgrn import HGRNAttention
from .hgrn2 import HGRN2Attention
from .lightnet import LightNetAttention
from .linear_attn import LinearAttention
from .multiscale_retention import MultiScaleRetention
from .nsa import NativeSparseAttention
from .rebased import ReBasedLinearAttention
from .rwkv6 import RWKV6Attention
from .rwkv7 import RWKV7Attention

__all__ = [
    'ABCAttention',
    'Attention',
    'BasedLinearAttention',
    'BitAttention',
    'DeltaNet',
    'ForgettingAttention',
    'GatedDeltaNet',
    'GatedDeltaProduct',
    'GatedLinearAttention',
    'GatedSlotAttention',
    'HGRNAttention',
    'HGRN2Attention',
    'LightNetAttention',
    'LinearAttention',
    'MultiScaleRetention',
    'NativeSparseAttention',
    'ReBasedLinearAttention',
    'RWKV6Attention',
    'RWKV7Attention',
]
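
For orientation, a minimal sketch of how these exports are meant to be used downstream. The layer choice and sizes are illustrative only, and running the Triton kernels behind most layers requires a CUDA device; simply constructing a module on CPU, as below, does not.

import torch
from fla.layers import GatedLinearAttention

# illustrative sizes, not taken from this commit
layer = GatedLinearAttention(hidden_size=1024, num_heads=4, layer_idx=0)
print(sum(p.numel() for p in layer.parameters()))  # parameter count of one mixer block
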
fla/layers/__pycache__/attn.cpython-312.pyc ADDED
Binary file (9.5 kB).
 
fla/layers/__pycache__/forgetting_attn.cpython-312.pyc ADDED
Binary file (5.3 kB).
 
fla/layers/__pycache__/gated_deltanet.cpython-312.pyc ADDED
Binary file (13.4 kB).
 
fla/layers/__pycache__/gated_deltaproduct.cpython-312.pyc ADDED
Binary file (14.4 kB).
 
fla/layers/__pycache__/gla.cpython-312.pyc ADDED
Binary file (13.3 kB).
 
fla/layers/__pycache__/multiscale_retention.cpython-312.pyc ADDED
Binary file (12.5 kB).
 
fla/layers/__pycache__/rwkv7.cpython-312.pyc ADDED
Binary file (10.5 kB).
 
fla/layers/based.py ADDED
@@ -0,0 +1,96 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

"""
Linear attention in Based.
https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py
"""

import torch
import torch.nn as nn
from einops import rearrange

from fla.modules.feature_map import TaylorFeatureMap
from fla.ops.based import parallel_based
from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn


class BasedLinearAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        feature_dim: int = 16,
        num_key_value_heads: int = 12,
        num_heads: int = 12,
        feature_name: str = "taylor_exp",
        eps: float = 1e-12,
        causal: bool = True,
        mode: str = "parallel",
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.mode = mode
        self.feature_name = feature_name
        self.feature_dim = feature_dim
        self.num_key_value_heads = num_key_value_heads
        self.num_heads = num_heads
        self.head_dim = self.hidden_size // self.num_key_value_heads
        assert self.hidden_size % self.head_dim == 0
        self.causal = causal

        self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.dropout = nn.Identity()
        self.feature_map = TaylorFeatureMap(feature_dim)
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor, **kwargs):
        mode = self.mode
        q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
        q, k, v = map(lambda x: rearrange(x, "... (h d) -> ... h d", d=self.head_dim), [q, k, v])
        if mode == "fused_chunk":
            q, k = self.feature_map(q), self.feature_map(k)
            o, _ = fused_chunk_linear_attn(q, k, v, normalize=True, scale=1, head_first=False)
        elif mode == 'chunk':
            q, k = self.feature_map(q), self.feature_map(k)
            o, _ = chunk_linear_attn(q, k, v, normalize=True, scale=1, head_first=False)
        elif mode == 'parallel':
            assert q.shape[-1] <= 128
            o = parallel_based(q, k, v, scale=1, use_norm=True, head_first=False)
        o = rearrange(o, 'b t h d -> b t (h d)')
        o = self.o_proj(o)
        o = self.dropout(o)
        return o

    # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119

    def forward_reference(self, hidden_states: torch.Tensor, filters: torch.Tensor = None, *args, **kwargs):
        """
        x (torch.Tensor): tensor of shape (b, d, t)
        y (torch.Tensor): tensor of shape (b, d, t)
        """
        # hidden_states = hidden_states.transpose(1, 2)
        b, t, _ = hidden_states.size()
        q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)

        q = q.view(b, t, self.num_heads, self.feature_dim).transpose(1, 2)
        k = k.view(b, t, self.num_key_value_heads, self.feature_dim).transpose(1, 2)
        v = v.view(b, t, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Linear attention
        q, k = self.feature_map(q), self.feature_map(k)
        q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1)

        # Compute attention
        if self.causal:
            y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps))
        else:
            y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps))
        y = rearrange(y, 'b h t d -> b t (h d)')
        y = self.o_proj(y.to(hidden_states.dtype))
        y = self.dropout(y)
        return y.to(hidden_states.dtype)
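
The causal branch of `forward_reference` above is the standard normalized linear attention written with cumulative sums. A small self-contained check of that identity on toy tensors (illustrative shapes, plain PyTorch, no fla kernels involved):

import torch

b, h, t, d, dv = 1, 2, 5, 4, 4
q = torch.randn(b, h, t, d).relu() + 1e-3   # any non-negative feature map works for this check
k = torch.randn(b, h, t, d).relu() + 1e-3
v = torch.randn(b, h, t, dv)

# cumulative-sum form used in forward_reference
qe, ke, ve = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1)
y = (qe * (ke * ve).cumsum(2)).sum(-1) / ((qe * ke.cumsum(2)).sum(-1) + 1e-12)

# explicit causal attention it is equivalent to: y_t = sum_{s<=t} (q_t . k_s) v_s / sum_{s<=t} q_t . k_s
attn = torch.einsum('bhtd,bhsd->bhts', q, k).tril()
ref = (attn @ v) / attn.sum(-1, keepdim=True)
assert torch.allclose(y, ref, atol=1e-5)
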
fla/layers/bitattn.py ADDED
@@ -0,0 +1,192 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import rearrange
from transformers.utils import logging

from fla.modules import RotaryEmbedding
from fla.modules.fused_bitlinear import FusedBitLinear

if TYPE_CHECKING:
    from fla.models.utils import Cache

try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
except ImportError:
    warnings.warn(
        "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
        category=ImportWarning
    )
    flash_attn_func = None

logger = logging.get_logger(__name__)


class BitAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int = 2048,
        num_heads: int = 32,
        num_kv_heads: Optional[int] = None,
        window_size: Optional[int] = None,
        rope_theta: Optional[float] = 10000.,
        max_position_embeddings: Optional[int] = None,
        norm_eps: float = 1e-5,
        layer_idx: int = None
    ):
        super().__init__()

        self.num_heads = num_heads
        if num_kv_heads is None:
            self.num_kv_heads = self.num_heads
        else:
            self.num_kv_heads = num_kv_heads
        self.num_kv_groups = num_heads // self.num_kv_heads
        self.hidden_size = hidden_size
        self.head_dim = self.hidden_size // self.num_heads
        self.kv_dim = self.num_kv_heads * self.head_dim
        self.window_size = window_size
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.layer_idx = layer_idx

        self.q_proj = FusedBitLinear(self.hidden_size, self.hidden_size, bias=False)
        self.k_proj = FusedBitLinear(self.hidden_size, self.kv_dim, bias=False)
        self.v_proj = FusedBitLinear(self.hidden_size, self.kv_dim, bias=False)
        self.o_proj = FusedBitLinear(self.hidden_size, self.hidden_size, bias=False)

        self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        batch_size, q_len, _ = hidden_states.size()

        q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
        k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
        v = rearrange(self.v_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)

        # equivalent to cu_seqlens in `flash_attn`
        cu_seqlens = kwargs.get('cu_seqlens', None)

        seqlen_offset, max_seqlen = 0, q_len
        if past_key_values is not None:
            seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
            max_seqlen = q.shape[1] + seqlen_offset

            if attention_mask is not None:
                # to account for the offsets of padding tokens
                seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
                max_seqlen = q.shape[1] + max(seqlen_offset)

        if self.max_position_embeddings is not None:
            max_seqlen = max(max_seqlen, self.max_position_embeddings)
        q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)

        if past_key_values is not None:
            cache_has_content = past_key_values.get_seq_length(self.layer_idx) > 0
            k_cached, v_cached = past_key_values.update(
                attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)),
                layer_idx=self.layer_idx,
                offset=q_len,
                cache_kwargs=dict(window_size=self.window_size)
            )['attn_state']
            if cache_has_content:
                k, v = k_cached, v_cached
                k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
                v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)

        if flash_attn_func is None:
            raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first")

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            q, k, v, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(q, k, v, attention_mask, q_len)
            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_q, max_seqlen_k = max_seq_lens
            o = flash_attn_varlen_func(
                q, k, v,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_k=max_seqlen_k,
                causal=True,
                window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
            )
            o = pad_input(o, indices_q, batch_size, q_len)
        elif cu_seqlens is not None:
            o = flash_attn_varlen_func(
                q.squeeze(0), k.squeeze(0), v.squeeze(0),
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                causal=True,
                window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
            ).unsqueeze(0)
        else:
            o = flash_attn_func(
                q, k, v,
                causal=True,
                window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0)
            )
        o = o.reshape(batch_size, q_len, -1)
        o = self.o_proj(o)

        if not output_attentions:
            attentions = None

        return o, attentions, past_key_values

    def _upad_input(self, q, k, v, attention_mask, q_len):
        batch_size, seq_len, num_key_value_heads, head_dim = k.shape
        cache_mask = attention_mask[:, -seq_len:]
        seqlens = cache_mask.sum(-1, dtype=torch.int32)
        indices_k = torch.nonzero(cache_mask.flatten(), as_tuple=False).flatten()
        max_seqlen_k = seqlens.max().item()
        cu_seqlens_k = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))

        k = index_first_axis(k.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
        v = index_first_axis(v.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
        if q_len == seq_len:
            q = index_first_axis(q.reshape(batch_size * seq_len, self.num_heads, head_dim), indices_k)
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_q = max_seqlen_k
            indices_q = indices_k
        elif q_len == 1:
            max_seqlen_q = 1
            # There is a memcpy here, that is very bad.
            cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device)
            indices_q = cu_seqlens_q[:-1]
            q = q.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -q_len:]
            q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask)

        return q, k, v, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
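
`_upad_input` above packs a left-padded batch into the varlen layout expected by `flash_attn_varlen_func`. A short sketch of how the `cu_seqlens` offsets are derived from a 0/1 padding mask (toy values, plain PyTorch):

import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])  # batch of 2, left padding
seqlens = attention_mask.sum(-1, dtype=torch.int32)                          # tensor([3, 5])
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens)  # tensor([0, 3, 8]): sequence i occupies rows cu_seqlens[i]:cu_seqlens[i+1]
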
fla/layers/hgrn.py ADDED
@@ -0,0 +1,168 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

# "Hierarchically Gated Recurrent Neural Network for Sequence Modeling" [https://arxiv.org/abs/2311.04823]

from __future__ import annotations

from typing import TYPE_CHECKING, Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from fla.modules import FusedRMSNormGated, ShortConvolution
from fla.modules.activations import swiglu
from fla.ops.hgrn import chunk_hgrn, fused_recurrent_hgrn

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack

    from fla.models.utils import Cache


class HGRNAttention(nn.Module):

    def __init__(
        self,
        mode: str = 'chunk',
        hidden_size: int = 1024,
        expand_ratio: Optional[int] = 1,
        use_short_conv: bool = False,
        conv_size: int = 4,
        conv_bias: bool = False,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-5,
        layer_idx: int = None
    ) -> HGRNAttention:
        super().__init__()

        self.mode = mode
        self.hidden_size = hidden_size
        self.expand_ratio = expand_ratio
        self.input_dim = int(hidden_size * expand_ratio)

        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias

        self.layer_idx = layer_idx

        assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."

        self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
        self.f_proj = nn.Linear(hidden_size, self.input_dim, bias=False)
        self.g_proj = nn.Linear(hidden_size, self.input_dim, bias=False)

        if use_short_conv:
            self.conv_size = conv_size
            self.q_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None)
            self.f_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None)
            self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None)

        self.g_norm = FusedRMSNormGated(
            hidden_size=self.input_dim,
            elementwise_affine=elementwise_affine,
            eps=norm_eps
        )
        self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        lower_bound: Optional[torch.Tensor] = None,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        # launching the triton kernel for just one token will actually be slower
        mode = 'fused_recurrent' if not self.training and hidden_states.shape[1] <= 64 else self.mode

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        cu_seqlens = kwargs.get('cu_seqlens', None)
        if self.use_short_conv:
            conv_state_i, conv_state_f = None, None
            if last_state is not None:
                conv_state_i, conv_state_f = last_state['conv_state']
            conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
            i, conv_state_i = self.i_conv1d(
                x=self.i_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_i,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            f, conv_state_f = self.f_conv1d(
                x=self.f_proj(hidden_states),
                mask=conv_mask,
                cache=conv_state_f,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
        else:
            i = self.i_proj(hidden_states)
            f = self.f_proj(hidden_states)

        # the lower bound for the first layer is zero
        if lower_bound is None or self.layer_idx == 0:
            i, f = swiglu(i, 1 - f.sigmoid()), F.logsigmoid(f)
        else:
            g = lower_bound + (1 - lower_bound) * f.sigmoid()
            i, f = swiglu(i, 1 - g), g.log()

        # dealing with left-padding
        if attention_mask is not None:
            i = i.mul_(attention_mask[:, -i.shape[-2]:, None])

        recurrent_state = last_state['recurrent_state'] if last_state is not None else None
        if mode == 'chunk':
            if cu_seqlens is not None:
                raise NotImplementedError("Chunk mode does not support variable-length sequences.")
            o, recurrent_state = chunk_hgrn(
                x=i,
                g=f,
                initial_state=recurrent_state,
                output_final_state=use_cache,
            )
        elif mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_hgrn(
                x=i,
                g=f,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=(conv_state_i, conv_state_f) if self.use_short_conv else None,
                layer_idx=self.layer_idx,
                offset=i.shape[2]
            )

        o = self.g_norm(o, self.g_proj(hidden_states))
        o = self.o_proj(o)

        return o, None, past_key_values

    def state_size(self, **kwargs) -> int:
        state_size = self.hidden_size
        for module in self.children():
            if isinstance(module, ShortConvolution):
                state_size += module.state_size
        return state_size
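
The forget gate in the forward pass above is squashed into `[lower_bound, 1)` for layers after the first, which is the hierarchical gating idea from the HGRN paper. A tiny numerical illustration (hypothetical values, not from the commit):

import torch

f = torch.randn(2, 8, 16)             # raw forget-gate pre-activations
lower_bound = torch.full((16,), 0.3)  # hypothetical per-channel lower bound from the model
g = lower_bound + (1 - lower_bound) * f.sigmoid()
assert torch.all(g > 0.3) and torch.all(g < 1.0)
# the recurrence then uses g.log() as the decay and swiglu(i, 1 - g) as the gated input
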
fla/layers/linear_attn.py ADDED
@@ -0,0 +1,166 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from fla.modules import RMSNorm
from fla.modules.feature_map import DPFPFeatureMap, HadamardFeatureMap, HedgehogFeatureMap, T2RFeatureMap
from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn, fused_recurrent_linear_attn


class LinearAttention(nn.Module):

    def __init__(
        self,
        mode: str = 'chunk',
        hidden_size: int = 1024,
        expand_k: float = 1.0,
        expand_v: float = 1.0,
        num_heads: int = 8,
        num_kv_heads: Optional[int] = None,
        feature_map: str = 'elementwise_product',
        tie_feature_map_qk: bool = False,
        output_norm: str = 'rmsnorm',
        norm_q: bool = False,
        norm_k: bool = False,
        do_feature_map_norm: bool = False,
        elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        **kwargs
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.mode = mode
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.key_dim = int(hidden_size * expand_k)
        self.value_dim = int(hidden_size * expand_v)
        self.key_dim_per_group = self.key_dim // self.num_kv_groups
        self.value_dim_per_group = self.value_dim // self.num_kv_groups

        assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
        assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
        assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"

        self.head_k_dim = self.key_dim // num_heads
        self.head_v_dim = self.value_dim // num_heads
        self.do_feature_map_norm = do_feature_map_norm

        if feature_map == 'hedgehog':
            if tie_feature_map_qk:
                self.feature_map_q = self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_k_dim)
            else:
                self.feature_map_q = HedgehogFeatureMap(head_dim=self.head_k_dim)
                self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_k_dim)

        elif feature_map == 't2r':
            if tie_feature_map_qk:
                self.feature_map_q = self.feature_map_k = T2RFeatureMap(head_dim=self.head_k_dim)
            else:
                self.feature_map_q = T2RFeatureMap(head_dim=self.head_k_dim)
                self.feature_map_k = T2RFeatureMap(head_dim=self.head_k_dim)

        elif feature_map == 'elementwise_product':
            if tie_feature_map_qk:
                self.feature_map_q = self.feature_map_k = HadamardFeatureMap(head_dim=self.head_k_dim)
            else:
                self.feature_map_q = HadamardFeatureMap(head_dim=self.head_k_dim)
                self.feature_map_k = HadamardFeatureMap(head_dim=self.head_k_dim)

        elif feature_map == 'dpfp':
            self.feature_map_q = DPFPFeatureMap(head_dim=self.head_k_dim)
            self.feature_map_k = DPFPFeatureMap(head_dim=self.head_k_dim)

        elif feature_map == 'elu':
            def elu(x):
                return F.elu(x) + 1
            self.feature_map_q = elu
            self.feature_map_k = elu

        elif feature_map == 'relu':
            self.feature_map_q = nn.ReLU()
            self.feature_map_k = nn.ReLU()

        elif feature_map == 'identity':
            self.feature_map_q = nn.Identity()
            self.feature_map_k = nn.Identity()
        else:
            raise NotImplementedError(f"Not supported feature map `{feature_map}`.")

        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)

        if output_norm == 'rmsnorm':
            self.norm = RMSNorm(hidden_size=self.head_v_dim, elementwise_affine=elementwise_affine, eps=norm_eps)
        elif output_norm == 'identity':
            self.norm = nn.Identity()
        else:
            raise NotImplementedError(f"Not supported output norm `{output_norm}`.")

        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

        self.norm_q = norm_q
        self.norm_k = norm_k

    def forward(
        self,
        hidden_states: torch.Tensor,
        **kwargs
    ) -> torch.Tensor:
        mode = self.mode
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        q = rearrange(q, '... (h d) -> ... h d', d=self.head_k_dim)
        if self.num_kv_groups > 1:
            k = repeat(k, '... (h d) -> ... (h g) d', d=self.head_k_dim, g=self.num_kv_groups)
            v = repeat(v, '... (h d) -> ... (h g) d', d=self.head_v_dim, g=self.num_kv_groups)
        else:
            k = rearrange(k, '... (h d) -> ... h d', d=self.head_k_dim)
            v = rearrange(v, '... (h d) -> ... h d', d=self.head_v_dim)

        q = self.feature_map_q(q)
        k = self.feature_map_k(k)

        if self.norm_q:
            q = q / (q.sum(-1, True) + 1e-4)
        if self.norm_k:
            k = k / (k.sum(-1, True) + 1e-4)

        if mode == 'chunk':
            o, final_state = chunk_linear_attn(
                q=q,
                k=k,
                v=v,
                normalize=self.do_feature_map_norm,
                head_first=False
            )
        elif mode == 'fused_chunk':
            o, final_state = fused_chunk_linear_attn(
                q=q,
                k=k,
                v=v,
                normalize=self.do_feature_map_norm,
            )
        elif mode == 'fused_recurrent':
            o, final_state = fused_recurrent_linear_attn(
                q=q,
                k=k,
                v=v,
                normalize=self.do_feature_map_norm,
            )
        else:
            raise NotImplementedError
        o = self.norm(o)
        o = rearrange(o, '... h d -> ... (h d)')
        o = self.o_proj(o)
        return o
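
A quick note on the feature maps above: options such as `'elu'` keep the mapped queries/keys strictly positive, which is what makes the optional `norm_q`/`norm_k` normalization (dividing by the feature sum) safe. A minimal illustration in plain PyTorch (illustrative sizes only):

import torch
import torch.nn.functional as F

q = torch.randn(2, 4, 8, 16)           # [batch, seq, heads, head_k_dim]
phi_q = F.elu(q) + 1                   # the 'elu' feature map; always > 0
assert torch.all(phi_q > 0)
q_normed = phi_q / (phi_q.sum(-1, True) + 1e-4)
print(q_normed.sum(-1).mean())         # each feature vector now sums to roughly 1
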
fla/layers/rwkv6.py ADDED
@@ -0,0 +1,307 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

# "Eagle and Finch: RWKV with Matrix-Valued States and Dynamic Recurrence" [https://arxiv.org/abs/2404.05892]

from __future__ import annotations

from typing import TYPE_CHECKING, Optional, Tuple

import torch
import torch.nn as nn
from einops import rearrange

from fla.modules import GroupNorm
from fla.modules.activations import ACT2FN
from fla.ops.rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6

if TYPE_CHECKING:
    from fla.models.utils import Cache


class RWKV6Attention(nn.Module):

    def __init__(
        self,
        mode: str = 'chunk',
        hidden_size: int = 1024,
        expand_k: float = 0.5,
        expand_v: float = 1.0,
        num_heads: int = 4,
        gate_fn: str = 'swish',
        proj_low_rank_dim: int = 32,
        gate_low_rank_dim: int = 64,
        fuse_norm: bool = True,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-5,
        layer_idx: int = None,
        **kwargs
    ) -> RWKV6Attention:
        super().__init__()

        self.mode = mode
        self.hidden_size = hidden_size
        self.expand_k = expand_k
        self.expand_v = expand_v
        self.num_heads = num_heads
        self.proj_low_rank_dim = proj_low_rank_dim
        self.gate_low_rank_dim = gate_low_rank_dim

        self.key_dim = int(hidden_size * expand_k)
        self.value_dim = int(hidden_size * expand_v)
        self.layer_idx = layer_idx

        assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
        assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
        assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"

        self.head_k_dim = self.key_dim // num_heads
        self.head_v_dim = self.value_dim // num_heads

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.x_proj = nn.Sequential(
            LerpLinear(hidden_size, proj_low_rank_dim * 5),
            nn.Tanh(),
            nn.Linear(proj_low_rank_dim * 5, hidden_size, bias=False)
        )
        self.x_bias = nn.Parameter(torch.zeros(5, hidden_size))

        self.r_proj = DDLerpLinear(hidden_size, self.key_dim)
        self.w_proj = DDLerpLinear(hidden_size, self.key_dim, low_rank_dim=gate_low_rank_dim)
        self.k_proj = DDLerpLinear(hidden_size, self.key_dim)
        self.v_proj = DDLerpLinear(hidden_size, self.value_dim)
        self.g_proj = DDLerpLinear(hidden_size, self.value_dim)
        self.bonus = nn.Parameter(torch.zeros(num_heads, self.head_k_dim))

        # TODO: fuse GroupNorm and output gate
        self.g_norm = GroupNorm(self.num_heads, self.value_dim, elementwise_affine=elementwise_affine, bias=True, eps=norm_eps)
        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
        self.gate_fn = ACT2FN[gate_fn]

        self.apply(self._initialize_weights)

    def _initialize_weights(self, module: nn.Module):
        if getattr(module, "_is_hf_initialized", False):
            return
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        if isinstance(module, nn.Parameter):
            nn.init.xavier_uniform_(module, gain=2 ** -2.5)
        module._is_hf_initialized = True

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        batch_size, seq_len, hidden_size = hidden_states.shape
        # launching the triton kernel for just one token will actually be slower
        mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        if attention_mask is not None:
            hidden_states = hidden_states.mul_(attention_mask[:, -hidden_states.shape[-2]:, None])
        if hidden_states.shape[1] == 1 and last_state is not None:
            shifted = last_state['conv_state'].unsqueeze(1)
        else:
            shifted = self.time_shift(hidden_states)
            if last_state is not None:
                shifted[:, 0] = last_state['conv_state']

        delta = shifted - hidden_states
        x = self.x_proj[0](hidden_states, delta).view(batch_size, seq_len, -1, self.proj_low_rank_dim)
        x = torch.einsum('b t n r, h n r-> b t n h', self.x_proj[1](x), self.x_proj[2].weight.view(hidden_size, 5, -1))

        r, w, k, v, g = x.add_(self.x_bias).unbind(-2)
        r = self.r_proj(hidden_states, r, delta)
        w = self.w_proj(hidden_states, w, delta)
        k = self.k_proj(hidden_states, k, delta)
        v = self.v_proj(hidden_states, v, delta)
        g = self.g_proj(hidden_states, g, delta)

        # dealing with left-padding
        if attention_mask is not None:
            v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
        r, w, k = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', d=self.head_k_dim), (r, w, k))
        v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
        w = -torch.exp(w)
        u = self.bonus

        recurrent_state = last_state['recurrent_state'] if last_state is not None else None
        cu_seqlens = kwargs.get('cu_seqlens', None)
        if mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_rwkv6(
                r=r,
                k=k,
                v=v,
                w=w,
                u=u,
                scale=1.,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False
            )
        elif mode == 'chunk':
            o, recurrent_state = chunk_rwkv6(
                q=r,
                k=k,
                v=v,
                g=w,
                u=u,
                scale=1.,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False
            )
        else:
            raise NotImplementedError(f"Not supported mode `{mode}`.")

        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=hidden_states[:, -1],
                layer_idx=self.layer_idx,
                offset=r.shape[2]
            )

        o = self.g_norm(rearrange(o, '... h d -> ... (h d)')) * self.gate_fn(g)
        o = self.o_proj(o)

        return o, None, past_key_values


class LoRA(nn.Module):

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        low_rank_dim: int,
        bias: Optional[bool] = True,
        activation: Optional[str] = 'tanh'
    ):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.low_rank_dim = low_rank_dim
        self.bias = bias

        if activation is None:
            self.activation = nn.Identity()
        elif activation == 'sigmoid':
            self.activation = nn.Sigmoid()
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'relu':
            self.activation = nn.ReLU()
        else:
            raise ValueError(f"Not supported activation `{activation}`.")

        self.lora = nn.Sequential(
            nn.Linear(input_dim, low_rank_dim, bias=False),
            self.activation,
            nn.Linear(low_rank_dim, output_dim, bias=bias)
        )

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}("
        s += f"input_dim={self.input_dim}, low_rank_dim={self.low_rank_dim}, output_dim={self.output_dim}"
        if not self.bias:
            s += f", bias={self.bias}"
        s += ")"
        return s

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lora(x)


class LerpLinear(nn.Module):

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        low_rank_dim: Optional[int] = None
    ):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.low_rank_dim = low_rank_dim

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        if low_rank_dim is None:
            self.linear = nn.Linear(input_dim, output_dim, bias=False)
        else:
            self.linear = LoRA(input_dim, output_dim, low_rank_dim)
        self.mu = nn.Parameter(torch.zeros(input_dim))

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}({self.input_dim}, {self.output_dim}"
        if self.low_rank_dim is not None:
            s += f", low_rank_dim={self.low_rank_dim}"
        s += ")"
        return s

    def forward(self, x: torch.Tensor, delta: Optional[torch.Tensor] = None) -> torch.Tensor:
        if delta is None:
            shifted = self.time_shift(x)
            if len(shifted.shape) == 2:
                shifted = shifted.unsqueeze(1)
            delta = shifted - x
        return self.linear(x + delta * self.mu)


class DDLerpLinear(nn.Module):

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        low_rank_dim: Optional[int] = None
    ):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.low_rank_dim = low_rank_dim

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        if low_rank_dim is None:
            self.linear = nn.Linear(input_dim, output_dim, bias=False)
        else:
            self.linear = LoRA(input_dim, output_dim, low_rank_dim)

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}({self.input_dim}, {self.output_dim}"
        if self.low_rank_dim is not None:
            s += f", low_rank_dim={self.low_rank_dim}"
        s += ")"
        return s

    def forward(self, x: torch.Tensor, mu: torch.Tensor, delta: Optional[torch.Tensor] = None) -> torch.Tensor:
        if delta is None:
            shifted = self.time_shift(x)
            if len(shifted.shape) == 2:
                shifted = shifted.unsqueeze(1)
            delta = shifted - x
        return self.linear(x + delta * mu)
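
The `nn.ZeroPad2d((0, 0, 1, -1))` modules above implement RWKV's token shift: on a `[batch, time, channel]` tensor they move every token one step forward in time and zero-fill the first position, which is what `LerpLinear`/`DDLerpLinear` interpolate against. A toy check:

import torch
import torch.nn as nn

time_shift = nn.ZeroPad2d((0, 0, 1, -1))
x = torch.arange(1., 7.).view(1, 3, 2)  # 3 tokens, 2 channels
print(time_shift(x))
# tensor([[[0., 0.],
#          [1., 2.],
#          [3., 4.]]])  -> each position now sees the previous token
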
fla/layers/rwkv7.py ADDED
@@ -0,0 +1,221 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from __future__ import annotations

from typing import TYPE_CHECKING, Optional, Tuple

import torch
import torch.nn as nn
from einops import rearrange
from torch.nn import functional as F

from fla.layers.rwkv6 import LoRA
from fla.modules import GroupNorm
from fla.modules.l2norm import l2_norm
from fla.ops.rwkv7 import chunk_rwkv7, fused_recurrent_rwkv7

if TYPE_CHECKING:
    from fla.models.utils import Cache


class RWKV7Attention(nn.Module):

    def __init__(
        self,
        mode: str = 'chunk',
        hidden_size: int = 1024,
        head_dim: Optional[int] = 64,
        num_heads: Optional[int] = None,
        decay_low_rank_dim: int = 64,
        gate_low_rank_dim: int = 128,
        a_low_rank_dim: int = 64,
        v_low_rank_dim: int = 16,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-5,
        layer_idx: int = None,
        fuse_norm: bool = False,
        value_dim: int = None,
        **kwargs
    ) -> RWKV7Attention:
        super().__init__()

        self.mode = mode
        assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
        self.hidden_size = hidden_size

        self.key_dim = hidden_size
        self.value_dim = value_dim if value_dim is not None else hidden_size
        if head_dim is None and num_heads is None:
            raise ValueError("Either `head_dim` or `num_heads` must be specified.")
        elif head_dim is not None:
            self.head_dim = head_dim
            self.num_heads = int(hidden_size // head_dim)
        elif num_heads is not None:
            self.head_dim = int(hidden_size // num_heads)
            self.num_heads = num_heads
        self.head_v_dim = int(self.value_dim // self.num_heads)

        self.decay_low_rank_dim = decay_low_rank_dim
        self.gate_low_rank_dim = gate_low_rank_dim
        self.a_low_rank_dim = a_low_rank_dim
        self.v_low_rank_dim = v_low_rank_dim
        self.layer_idx = layer_idx
        self.fuse_norm = fuse_norm

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        self.x_x = nn.Parameter(torch.zeros(6, hidden_size))

        self.k_k = nn.Parameter(torch.zeros(self.key_dim))
        self.k_a = nn.Parameter(torch.zeros(self.key_dim))
        self.r_k = nn.Parameter(torch.zeros(self.num_heads, self.head_dim))

        self.r_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

        self.w_lora = LoRA(hidden_size, self.key_dim, low_rank_dim=decay_low_rank_dim, activation='tanh')
        if self.layer_idx != 0:
            self.v_lora = LoRA(hidden_size, self.value_dim, low_rank_dim=v_low_rank_dim, activation=None)
        self.a_lora = LoRA(hidden_size, self.key_dim, low_rank_dim=a_low_rank_dim, activation=None)
        self.g_lora = LoRA(hidden_size, self.value_dim, low_rank_dim=gate_low_rank_dim, activation='sigmoid', bias=False)

        if self.fuse_norm:
            self.g_norm = GroupNorm(
                num_groups=self.num_heads,
                hidden_size=self.value_dim,
                elementwise_affine=elementwise_affine,
                eps=self.head_dim*norm_eps,
                bias=True,
            )
        else:
            self.g_norm = nn.GroupNorm(
                num_groups=self.num_heads,
                num_channels=self.value_dim,
                eps=self.head_dim*norm_eps,
                affine=elementwise_affine
            )

        self.apply(self._initialize_weights)

    def _initialize_weights(self, module: nn.Module):
        if getattr(module, "_is_hf_initialized", False):
            return
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        if isinstance(module, nn.Parameter):
            nn.init.xavier_uniform_(module, gain=2 ** -2.5)
        module._is_hf_initialized = True

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        v_first: torch.Tensor = None,
        **kwargs
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        batch_size, seq_len, _ = hidden_states.shape

        if self.training:
            # if training, use chunk mode no matter how short the sequence is
            mode = 'chunk'
        else:
            # launching the triton kernel for just one token will actually be slower
            mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        if attention_mask is not None:
            hidden_states = hidden_states.mul(attention_mask[:, -hidden_states.shape[-2]:, None])
        if hidden_states.shape[1] == 1 and last_state is not None:
            shifted = last_state['conv_state'].unsqueeze(1)
        else:
            shifted = self.time_shift(hidden_states)
            if last_state is not None:
                shifted[:, 0] = last_state['conv_state']

        # [batch_size, seq_len, hidden_size]
        delta = shifted - hidden_states
        xr, xw, xk, xv, xa, xg = hidden_states.addcmul(delta, self.x_x.view(6, 1, 1, -1)).unbind(0)

        r = self.r_proj(xr)
        # -math.exp(-0.5) = -0.6065306597126334
        # I think .to(torch.float) is unnecessary here, since we calculate the lora in bfloat16;
        # when we apply sigmoid, bf16 input will not have numerical issues
        # FIXME: check if we can remove .to(torch.float)
        w = -0.6065306597126334 * self.w_lora(xw).to(torch.float).sigmoid()

        k = self.k_proj(xk)
        v = self.v_proj(xv)

        if self.layer_idx == 0:
            v_first = v
        else:
            v = torch.lerp(v, v_first, self.v_lora(xv).sigmoid())
        a = self.a_lora(xa).sigmoid()
        g = self.g_lora(xg)

        if self.fuse_norm:
            kk = l2_norm(rearrange(k * self.k_k, 'b t (h d) -> b t h d', d=self.head_dim))
        else:
            kk = F.normalize(rearrange(k * self.k_k, 'b t (h d) -> b t h d', d=self.head_dim), dim=-1, p=2.0)

        k = k.addcmul(k * (a - 1), self.k_a)

        # dealing with left-padding
        if attention_mask is not None:
            v = v * attention_mask[:, -v.shape[-2]:, None]
        r, w, k, a = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', d=self.head_dim), (r, w, k, a))
        v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)

        recurrent_state = last_state['recurrent_state'] if last_state is not None else None

        rwkv7_fn = chunk_rwkv7 if mode == 'chunk' else fused_recurrent_rwkv7
        cu_seqlens = kwargs.get('cu_seqlens', None)
        o, recurrent_state = rwkv7_fn(
            r=r,
            w=w,
            k=k,
            v=v,
            a=-kk,
            b=kk * a,
            scale=1.,
            initial_state=recurrent_state,
            output_final_state=use_cache,
            cu_seqlens=cu_seqlens,
            head_first=False
        )

        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=hidden_states[:, -1],
                layer_idx=self.layer_idx,
                offset=r.shape[1]
            )

        if self.fuse_norm:
            o = self.g_norm(rearrange(o, '... h d -> ... (h d)'))
        else:
            o = self.g_norm(rearrange(o, 'b t h d -> (b t) (h d)')).view(batch_size, seq_len, -1)

        o = o + ((r * k * self.r_k).sum(-1, keepdim=True) * v).view(batch_size, seq_len, -1)
        o = self.o_proj(o * g)

        return o, None, past_key_values, v_first
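
On the decay line `w = -0.6065306597126334 * self.w_lora(xw).to(torch.float).sigmoid()` above: assuming the kernel treats `w` as a log-decay, as the other log-space decays in this repository are, the per-channel multiplicative decay `exp(w)` is confined to roughly `(0.545, 1)`. A small sanity check of that range (toy logits, not from the commit):

import torch

w_logits = torch.randn(2, 8, 64)
w = -0.6065306597126334 * w_logits.sigmoid()   # -exp(-0.5) * sigmoid(.)
decay = w.exp()
assert decay.min() > 0.545 and decay.max() < 1.0
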
fla/layers/simple_gla.py ADDED
@@ -0,0 +1,261 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from einops import rearrange, repeat
12
+
13
+ from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
14
+ from fla.modules.activations import ACT2FN
15
+ from fla.ops.simple_gla import chunk_simple_gla, fused_recurrent_simple_gla
16
+
17
+ if TYPE_CHECKING:
18
+ from fla.models.utils import Cache
19
+
20
+
21
+ class SimpleGatedLinearAttention(nn.Module):
22
+ r"""
23
+ The layer implementaion for [Gated Linear Attention Transformers with Hardware-Efficient Training](https://arxiv.org/abs/2312.06635). # noqa
24
+ This layer calls the simplified GLA kernel in which the gating is head-wise instead of elementwise.
25
+
26
+ Args:
27
+ mode (str, Optional):
28
+ Which GLA kernel to use.
29
+ Currently available: `chunk`.
30
+ Default: `chunk`.
31
+ hidden_size (int, Optional):
32
+ The hidden size of the input. Default: 1024.
33
+ expand_k (float, Optional):
34
+ The expansion ratio for the key dim. Default: 1.0.
35
+ expand_v (float, Optional):
36
+ The expansion ratio for the value dim. Default: 1.0.
37
+ num_heads (int, Optional):
38
+ The number of heads. Default: 4.
39
+ num_kv_heads (int, Optional):
40
+ The number of key/value heads, used for MQA. Default: None.
41
+ feature_map (str, Optional):
42
+ Feature map function applied to queries/keys. Default: None.
43
+ use_short_conv (bool, Optional):
44
+ Whether to use short convolutions. Default: `False`.
45
+ conv_size (int, Optional):
46
+ The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
47
+ conv_bias (bool, Optional):
48
+ Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
49
+ gate_fn (str, Optional):
50
+ The activation function for the output gate. Default: `swish`.
51
+ elementwise_affine (bool, Optional):
52
+ If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
53
+ norm_eps (float, Optional):
54
+ The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
55
+ gate_logit_normalizer (int, Optional):
56
+ The normalizer for the gate logits, appied after `logsigmoid`. Default: 16.
57
+ fuse_norm (bool, Optional):
58
+ Whether to fuse the norm and the output gate for better memory footprint. Default: `True`.
59
+ layer_idx (int, Optional):
60
+ The index of the layer. Default: None.
61
+ """
62
+
63
+ def __init__(
64
+ self,
65
+ mode: str = 'chunk',
66
+ hidden_size: int = 1024,
67
+ expand_k: float = 1.,
68
+ expand_v: float = 1.,
69
+ num_heads: int = 4,
70
+ num_kv_heads: Optional[int] = None,
71
+ feature_map: Optional[str] = None,
72
+ use_short_conv: bool = True,
73
+ conv_size: int = 4,
74
+ conv_bias: bool = False,
75
+ gate_fn: str = 'swish',
76
+ elementwise_affine: Optional[bool] = True,
77
+ norm_eps: float = 1e-5,
78
+ gate_logit_normalizer: int = 16,
79
+ fuse_norm: bool = True,
80
+ layer_idx: int = None,
81
+ ) -> SimpleGatedLinearAttention:
82
+ super().__init__()
83
+
84
+ self.mode = mode
85
+ self.hidden_size = hidden_size
86
+ self.expand_k = expand_k
87
+ self.expand_v = expand_v
88
+ self.num_heads = num_heads
89
+ self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
90
+ self.num_kv_groups = self.num_heads // self.num_kv_heads
91
+ self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None
92
+
93
+ self.use_short_conv = use_short_conv
94
+ self.conv_size = conv_size
95
+ self.conv_bias = conv_bias
96
+
97
+ self.key_dim = int(hidden_size * expand_k)
98
+ self.value_dim = int(hidden_size * expand_v)
99
+ self.key_dim_per_group = self.key_dim // self.num_kv_groups
100
+ self.value_dim_per_group = self.value_dim // self.num_kv_groups
101
+ self.layer_idx = layer_idx
102
+
103
+ assert mode in ['chunk', "fused_recurrent"], f"Not suppoerted mode `{mode}`."
104
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
105
+ assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
106
+
107
+ self.head_k_dim = self.key_dim // num_heads
108
+ self.head_v_dim = self.value_dim // num_heads
109
+
110
+ self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
111
+ self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False)
112
+ self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False)
113
+ self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
114
+
115
+ if use_short_conv:
116
+ self.conv_size = conv_size
117
+ self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu')
118
+ self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu')
119
+ self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu')
120
+
121
+ self.gk_proj = nn.Linear(hidden_size, self.num_heads)
122
+
123
+ if gate_fn == 'swish' and fuse_norm:
124
+ self.g_norm_swish_gate = FusedRMSNormGated(
125
+ hidden_size=self.head_v_dim,
126
+ elementwise_affine=elementwise_affine,
127
+ eps=norm_eps
128
+ )
129
+ self.fuse_norm_and_gate = True
130
+ else:
131
+ self.fuse_norm_and_gate = False
132
+ self.g_norm = RMSNorm(
133
+ hidden_size=self.head_v_dim,
134
+ elementwise_affine=elementwise_affine,
135
+ eps=norm_eps
136
+ )
137
+ self.gate_fn = ACT2FN[gate_fn]
138
+ self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
139
+
140
+ self.gate_logit_normalizer = gate_logit_normalizer
141
+
142
+ def forward(
143
+ self,
144
+ hidden_states: torch.Tensor,
145
+ attention_mask: Optional[torch.Tensor] = None,
146
+ past_key_values: Optional[Cache] = None,
147
+ use_cache: Optional[bool] = False,
148
+ output_attentions: Optional[bool] = False,
149
+ **kwargs
150
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
151
+ if attention_mask is not None:
152
+ assert len(attention_mask.shape) == 2, (
153
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
154
+ "for padding purposes (0 indicating padding). "
155
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
156
+ )
157
+
158
+ # launching the triton kernel for just one token will actually be slower
159
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
160
+
161
+ last_state = None
162
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
163
+ last_state = past_key_values[self.layer_idx]
164
+
165
+ cu_seqlens = kwargs.get('cu_seqlens', None)
166
+ if self.use_short_conv:
167
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
168
+ if last_state is not None:
169
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
170
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
171
+ q, conv_state_q = self.q_conv1d(
172
+ x=self.q_proj(hidden_states),
173
+ mask=conv_mask,
174
+ cache=conv_state_q,
175
+ output_final_state=use_cache,
176
+ cu_seqlens=cu_seqlens
177
+ )
178
+ k, conv_state_k = self.k_conv1d(
179
+ x=self.k_proj(hidden_states),
180
+ mask=conv_mask,
181
+ cache=conv_state_k,
182
+ output_final_state=use_cache,
183
+ cu_seqlens=cu_seqlens
184
+ )
185
+ v, conv_state_v = self.v_conv1d(
186
+ x=self.v_proj(hidden_states),
187
+ mask=conv_mask,
188
+ cache=conv_state_v,
189
+ output_final_state=use_cache,
190
+ cu_seqlens=cu_seqlens
191
+ )
192
+ else:
193
+ q = self.q_proj(hidden_states)
194
+ k = self.k_proj(hidden_states)
195
+ v = self.v_proj(hidden_states)
196
+ gk = self.gk_proj(hidden_states)
197
+
198
+ if self.feature_map_fn is not None:
199
+ q, k = map(self.feature_map_fn, (q, k))
200
+ # dealing with left-padding
201
+ if attention_mask is not None:
202
+ v = v.mul_(attention_mask[:, -v.shape[-2]:, None])
203
+ q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads)
204
+ if self.num_kv_groups > 1:
205
+ k, v = (repeat(x, '... (h d) -> ... (h g) d', h=self.num_kv_heads, g=self.num_kv_groups) for x in (k, v))
206
+ else:
207
+ k, v = (rearrange(x, '... (h d) -> ... h d', h=self.num_kv_heads) for x in (k, v))
208
+ gk = F.logsigmoid(gk) / self.gate_logit_normalizer
209
+
210
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
211
+ if mode == 'chunk':
212
+ o, recurrent_state = chunk_simple_gla(
213
+ q=q,
214
+ k=k,
215
+ v=v,
216
+ gk=gk,
217
+ initial_state=recurrent_state,
218
+ output_final_state=use_cache,
219
+ cu_seqlens=cu_seqlens,
220
+ head_first=False
221
+ )
222
+ elif mode == 'fused_recurrent':
223
+ o, recurrent_state = fused_recurrent_simple_gla(
224
+ q=q,
225
+ k=k,
226
+ v=v,
227
+ gk=gk,
228
+ initial_state=recurrent_state,
229
+ output_final_state=use_cache,
230
+ cu_seqlens=cu_seqlens,
231
+ head_first=False
232
+ )
233
+ else:
234
+ raise NotImplementedError(f"Not supported mode `{mode}`.")
235
+
236
+ if past_key_values is not None:
237
+ past_key_values.update(
238
+ recurrent_state=recurrent_state,
239
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
240
+ layer_idx=self.layer_idx,
241
+ offset=q.shape[1]
242
+ )
243
+
244
+ g = self.g_proj(hidden_states)
245
+ if self.fuse_norm_and_gate:
246
+ g = rearrange(g, 'b t (h d) -> b t h d', h=self.num_heads)
247
+ o = self.g_norm_swish_gate(o, g)
248
+ o = rearrange(o, 'b t h d -> b t (h d)')
249
+ else:
250
+ o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)')
251
+ o = o * self.gate_fn(g)
252
+ o = self.o_proj(o)
253
+
254
+ return o, None, past_key_values
255
+
256
+ def state_size(self, **kwargs) -> int:
257
+ state_size = self.key_dim * self.head_v_dim
258
+ for module in self.children():
259
+ if isinstance(module, ShortConvolution):
260
+ state_size += module.state_size
261
+ return state_size
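Editor's note (not part of the commit): the layer above selects the `fused_recurrent` Triton kernel for prompts of 64 tokens or fewer and the `chunk` kernel otherwise, and its `forward` returns `(output, None, past_key_values)`. A minimal smoke-test sketch follows; the class name and export path are assumptions (this excerpt does not show the file header), and a CUDA device with Triton is required.

```python
import torch
from fla.layers import SimpleGatedLinearAttention  # assumed export name; adjust to the class defined above

layer = SimpleGatedLinearAttention(hidden_size=512, num_heads=4).cuda()
x = torch.randn(2, 128, 512, device='cuda')  # [batch, seq_len, hidden_size]
o, _, cache = layer(x)                       # (output, attentions placeholder, cache)
assert o.shape == x.shape
```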
fla/models/__init__.py ADDED
@@ -0,0 +1,53 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from fla.models.abc import ABCConfig, ABCForCausalLM, ABCModel
4
+ from fla.models.bitnet import BitNetConfig, BitNetForCausalLM, BitNetModel
5
+ from fla.models.delta_net import DeltaNetConfig, DeltaNetForCausalLM, DeltaNetModel
6
+ from fla.models.forgetting_transformer import (
7
+ ForgettingTransformerConfig,
8
+ ForgettingTransformerForCausalLM,
9
+ ForgettingTransformerModel
10
+ )
11
+ from fla.models.gated_deltanet import GatedDeltaNetConfig, GatedDeltaNetForCausalLM, GatedDeltaNetModel
12
+ from fla.models.gated_deltaproduct import GatedDeltaProductConfig, GatedDeltaProductForCausalLM, GatedDeltaProductModel
13
+ from fla.models.gla import GLAConfig, GLAForCausalLM, GLAModel
14
+ from fla.models.gsa import GSAConfig, GSAForCausalLM, GSAModel
15
+ from fla.models.hgrn import HGRNConfig, HGRNForCausalLM, HGRNModel
16
+ from fla.models.hgrn2 import HGRN2Config, HGRN2ForCausalLM, HGRN2Model
17
+ from fla.models.lightnet import LightNetConfig, LightNetForCausalLM, LightNetModel
18
+ from fla.models.linear_attn import LinearAttentionConfig, LinearAttentionForCausalLM, LinearAttentionModel
19
+ from fla.models.mamba import MambaConfig, MambaForCausalLM, MambaModel
20
+ from fla.models.mamba2 import Mamba2Config, Mamba2ForCausalLM, Mamba2Model
21
+ from fla.models.nsa import NSAConfig, NSAForCausalLM, NSAModel
22
+ from fla.models.retnet import RetNetConfig, RetNetForCausalLM, RetNetModel
23
+ from fla.models.rwkv6 import RWKV6Config, RWKV6ForCausalLM, RWKV6Model
24
+ from fla.models.rwkv7 import RWKV7Config, RWKV7ForCausalLM, RWKV7Model
25
+ from fla.models.samba import SambaConfig, SambaForCausalLM, SambaModel
26
+ from fla.models.transformer import TransformerConfig, TransformerForCausalLM, TransformerModel
27
+ from fla.models.transformer_top import TOPTransformerConfig, TOPTransformerForCausalLM, TOPTransformerModel
28
+ from fla.models.transformer_mtp import MTPTransformerConfig, MTPTransformerForCausalLM, MTPTransformerModel
29
+
30
+ __all__ = [
31
+ 'ABCConfig', 'ABCForCausalLM', 'ABCModel',
32
+ 'BitNetConfig', 'BitNetForCausalLM', 'BitNetModel',
33
+ 'DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel',
34
+ 'ForgettingTransformerConfig', 'ForgettingTransformerForCausalLM', 'ForgettingTransformerModel',
35
+ 'GatedDeltaNetConfig', 'GatedDeltaNetForCausalLM', 'GatedDeltaNetModel',
36
+ 'GLAConfig', 'GLAForCausalLM', 'GLAModel',
37
+ 'GSAConfig', 'GSAForCausalLM', 'GSAModel',
38
+ 'HGRNConfig', 'HGRNForCausalLM', 'HGRNModel',
39
+ 'HGRN2Config', 'HGRN2ForCausalLM', 'HGRN2Model',
40
+ 'LightNetConfig', 'LightNetForCausalLM', 'LightNetModel',
41
+ 'LinearAttentionConfig', 'LinearAttentionForCausalLM', 'LinearAttentionModel',
42
+ 'MambaConfig', 'MambaForCausalLM', 'MambaModel',
43
+ 'Mamba2Config', 'Mamba2ForCausalLM', 'Mamba2Model',
44
+ 'NSAConfig', 'NSAForCausalLM', 'NSAModel',
45
+ 'RetNetConfig', 'RetNetForCausalLM', 'RetNetModel',
46
+ 'RWKV6Config', 'RWKV6ForCausalLM', 'RWKV6Model',
47
+ 'RWKV7Config', 'RWKV7ForCausalLM', 'RWKV7Model',
48
+ 'SambaConfig', 'SambaForCausalLM', 'SambaModel',
49
+ 'TransformerConfig', 'TransformerForCausalLM', 'TransformerModel',
50
+ 'TOPTransformerConfig', 'TOPTransformerForCausalLM', 'TOPTransformerModel',
51
+ 'MTPTransformerConfig', 'MTPTransformerForCausalLM', 'MTPTransformerModel',
52
+ 'GatedDeltaProductConfig', 'GatedDeltaProductForCausalLM', 'GatedDeltaProductModel',
53
+ ]
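Editor's note (not part of the commit): importing `fla.models` pulls in every submodule listed above, and each submodule `__init__.py` (e.g. the BitNet one below) registers its config and model classes with the `transformers` Auto classes. A hedged sketch of what that enables, with illustrative toy sizes; the exact config fields follow the respective `configuration_*.py`, which this excerpt does not show in full:

```python
from transformers import AutoModelForCausalLM

from fla.models import DeltaNetConfig  # importing fla.models also runs the Auto* registrations

config = DeltaNetConfig(vocab_size=32000, hidden_size=256, num_hidden_layers=2)  # illustrative sizes
model = AutoModelForCausalLM.from_config(config)  # resolved via the registration in fla/models/delta_net/__init__.py
print(model.config.model_type, sum(p.numel() for p in model.parameters()))
```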
fla/models/abc/__pycache__/configuration_abc.cpython-312.pyc ADDED
Binary file (3.61 kB). View file
 
fla/models/bitnet/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.bitnet.configuration_bitnet import BitNetConfig
6
+ from fla.models.bitnet.modeling_bitnet import BitNetForCausalLM, BitNetModel
7
+
8
+ AutoConfig.register(BitNetConfig.model_type, BitNetConfig)
9
+ AutoModel.register(BitNetConfig, BitNetModel)
10
+ AutoModelForCausalLM.register(BitNetConfig, BitNetForCausalLM)
11
+
12
+
13
+ __all__ = ['BitNetConfig', 'BitNetForCausalLM', 'BitNetModel']
fla/models/delta_net/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.delta_net.configuration_delta_net import DeltaNetConfig
6
+ from fla.models.delta_net.modeling_delta_net import DeltaNetForCausalLM, DeltaNetModel
7
+
8
+ AutoConfig.register(DeltaNetConfig.model_type, DeltaNetConfig)
9
+ AutoModel.register(DeltaNetConfig, DeltaNetModel)
10
+ AutoModelForCausalLM.register(DeltaNetConfig, DeltaNetForCausalLM)
11
+
12
+ __all__ = ['DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel']
fla/models/delta_net/__pycache__/modeling_delta_net.cpython-312.pyc ADDED
Binary file (18.5 kB). View file
 
fla/models/delta_net/modeling_delta_net.py ADDED
@@ -0,0 +1,415 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.delta_net import DeltaNet
20
+ from fla.models.delta_net.configuration_delta_net import DeltaNetConfig
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as DeltaNetMLP
24
+ from fla.modules import RMSNorm
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+ if TYPE_CHECKING:
29
+ from transformers.processing_utils import Unpack
30
+
31
+
32
+ class DeltaNetBlock(nn.Module):
33
+ def __init__(self, config: DeltaNetConfig, layer_idx: int):
34
+ super().__init__()
35
+
36
+ self.config = config
37
+ self.layer_idx = layer_idx
38
+
39
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
40
+ if config.attn is not None and layer_idx in config.attn['layers']:
41
+ self.attn = Attention(
42
+ hidden_size=config.hidden_size,
43
+ num_heads=config.attn['num_heads'],
44
+ num_kv_heads=config.attn['num_kv_heads'],
45
+ qkv_bias=config.attn['qkv_bias'],
46
+ window_size=config.attn['window_size'],
47
+ rope_theta=config.attn['rope_theta'],
48
+ max_position_embeddings=config.max_position_embeddings,
49
+ layer_idx=layer_idx
50
+ )
51
+ else:
52
+ self.attn = DeltaNet(
53
+ mode=config.attn_mode,
54
+ hidden_size=config.hidden_size,
55
+ expand_k=config.expand_k,
56
+ expand_v=config.expand_v,
57
+ num_heads=config.num_heads,
58
+ use_gate=config.use_gate,
59
+ use_beta=config.use_beta,
60
+ use_short_conv=config.use_short_conv,
61
+ use_output_norm=config.use_output_norm,
62
+ conv_size=config.conv_size,
63
+ qk_norm=config.qk_norm,
64
+ qk_activation=config.qk_activation,
65
+ norm_eps=config.norm_eps,
66
+ layer_idx=layer_idx
67
+ )
68
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
69
+ self.mlp = DeltaNetMLP(
70
+ hidden_size=config.hidden_size,
71
+ hidden_ratio=config.hidden_ratio,
72
+ intermediate_size=config.intermediate_size,
73
+ hidden_act=config.hidden_act,
74
+ fuse_swiglu=config.fuse_swiglu
75
+ )
76
+
77
+ def forward(
78
+ self,
79
+ hidden_states: torch.Tensor,
80
+ attention_mask: Optional[torch.Tensor] = None,
81
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
82
+ use_cache: Optional[bool] = False,
83
+ output_attentions: Optional[bool] = False,
84
+ **kwargs: Unpack[Dict]
85
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
86
+ residual = hidden_states
87
+ hidden_states = self.attn_norm(hidden_states)
88
+ hidden_states, attentions, past_key_values = self.attn(
89
+ hidden_states=hidden_states,
90
+ attention_mask=attention_mask,
91
+ past_key_values=past_key_values,
92
+ use_cache=use_cache,
93
+ output_attentions=output_attentions,
94
+ **kwargs
95
+ )
96
+ if self.config.fuse_norm:
97
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
98
+ else:
99
+ hidden_states = residual + hidden_states
100
+ residual = hidden_states
101
+ hidden_states = self.mlp_norm(hidden_states)
102
+ hidden_states = self.mlp(hidden_states, **kwargs)
103
+ hidden_states = residual + hidden_states
104
+
105
+ outputs = (hidden_states, attentions, past_key_values)
106
+
107
+ return outputs
108
+
109
+
110
+ class DeltaNetPreTrainedModel(PreTrainedModel):
111
+
112
+ config_class = DeltaNetConfig
113
+ base_model_prefix = 'model'
114
+ supports_gradient_checkpointing = True
115
+ _no_split_modules = ['DeltaNetBlock']
116
+ _supports_cache_class = True
117
+
118
+ def __init__(self, *inputs, **kwargs):
119
+ super().__init__(*inputs, **kwargs)
120
+
121
+ def _init_weights(
122
+ self,
123
+ module: nn.Module,
124
+ prenorm_residual_strategy: Optional[str] = 'rescale',
125
+ num_residuals_per_layer: int = 2,
126
+ ):
127
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
128
+ # Slightly different from the TF version which uses truncated_normal for initialization
129
+ # cf https://github.com/pytorch/pytorch/pull/5617
130
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
131
+ if module.bias is not None:
132
+ nn.init.zeros_(module.bias)
133
+ elif isinstance(module, nn.Embedding):
134
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
135
+ elif hasattr(module, 'reset_parameters'):
136
+ module.reset_parameters()
137
+
138
+ if prenorm_residual_strategy is not None:
139
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
140
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
141
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
142
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
143
+ #
144
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
145
+ p = None
146
+ if hasattr(module, 'o_proj'):
147
+ p = module.o_proj.weight
148
+ elif hasattr(module, 'down_proj'):
149
+ p = module.down_proj.weight
150
+ if p is not None:
151
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
152
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
153
+ # We need to reinit p since this code could be called multiple times
154
+ # Having just p *= scale would repeatedly scale it down
155
+ if prenorm_residual_strategy == 'rescale':
156
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
157
+ with torch.no_grad():
158
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
159
+ elif prenorm_residual_strategy == 'zero':
160
+ nn.init.zeros_(p)
161
+ else:
162
+ raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
163
+
164
+
165
+ class DeltaNetModel(DeltaNetPreTrainedModel):
166
+
167
+ def __init__(self, config: DeltaNetConfig):
168
+ super().__init__(config)
169
+ self.padding_idx = config.pad_token_id
170
+ self.vocab_size = config.vocab_size
171
+
172
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
173
+ self.layers = nn.ModuleList([DeltaNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
174
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
175
+
176
+ self.gradient_checkpointing = False
177
+
178
+ self.post_init()
179
+
180
+ def get_input_embeddings(self):
181
+ return self.embeddings
182
+
183
+ def set_input_embeddings(self, value):
184
+ self.embeddings = value
185
+
186
+ def forward(
187
+ self,
188
+ input_ids: Optional[torch.LongTensor] = None,
189
+ attention_mask: Optional[torch.Tensor] = None, # noqa
190
+ inputs_embeds: Optional[torch.FloatTensor] = None,
191
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
192
+ use_cache: Optional[bool] = None,
193
+ output_attentions: Optional[bool] = None,
194
+ output_hidden_states: Optional[bool] = None,
195
+ return_dict: Optional[bool] = None,
196
+ **kwargs: Unpack[Dict]
197
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
198
+ if output_attentions:
199
+ warnings.warn("`DeltaNetModel` does not support `output_attentions` for now, setting it to `False`.")
200
+ output_attentions = False
201
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
202
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
203
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
204
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
205
+
206
+ # retrieve input_ids and inputs_embeds
207
+ if input_ids is not None and inputs_embeds is not None:
208
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
209
+ if input_ids is None and inputs_embeds is None:
210
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
211
+
212
+ if inputs_embeds is None:
213
+ inputs_embeds = self.embeddings(input_ids)
214
+ hidden_states = inputs_embeds
215
+
216
+ if use_cache and not isinstance(past_key_values, Cache):
217
+ past_key_values = Cache.from_legacy_cache(past_key_values)
218
+
219
+ if self.gradient_checkpointing and self.training and use_cache:
220
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
221
+ use_cache = False
222
+
223
+ all_hidden_states = () if output_hidden_states else None
224
+ all_attns = () if output_attentions else None
225
+ for layer in self.layers:
226
+ if output_hidden_states:
227
+ all_hidden_states += (hidden_states,)
228
+
229
+ if self.gradient_checkpointing and self.training:
230
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
231
+ layer.__call__,
232
+ hidden_states,
233
+ attention_mask,
234
+ past_key_values,
235
+ use_cache,
236
+ output_attentions,
237
+ **kwargs
238
+ )
239
+ else:
240
+ hidden_states, attentions, past_key_values = layer(
241
+ hidden_states,
242
+ attention_mask=attention_mask,
243
+ past_key_values=past_key_values,
244
+ use_cache=use_cache,
245
+ output_attentions=output_attentions,
246
+ **kwargs
247
+ )
248
+
249
+ if output_attentions:
250
+ all_attns += (attentions,)
251
+
252
+ hidden_states = self.norm(hidden_states)
253
+
254
+ # add hidden states from the last decoder layer
255
+ if output_hidden_states:
256
+ all_hidden_states += (hidden_states,)
257
+
258
+ if not return_dict:
259
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
260
+ return BaseModelOutputWithPast(
261
+ last_hidden_state=hidden_states,
262
+ past_key_values=past_key_values,
263
+ hidden_states=all_hidden_states,
264
+ attentions=all_attns
265
+ )
266
+
267
+
268
+ class DeltaNetForCausalLM(DeltaNetPreTrainedModel, GenerationMixin):
269
+
270
+ _tied_weights_keys = ["lm_head.weight"]
271
+
272
+ def __init__(self, config):
273
+ super().__init__(config)
274
+ self.model = DeltaNetModel(config)
275
+ self.vocab_size = config.vocab_size
276
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
277
+ self.criterion = None
278
+
279
+ # Initialize weights and apply final processing
280
+ self.post_init()
281
+
282
+ def get_input_embeddings(self):
283
+ return self.model.embeddings
284
+
285
+ def set_input_embeddings(self, value):
286
+ self.model.embeddings = value
287
+
288
+ def get_output_embeddings(self):
289
+ return self.lm_head
290
+
291
+ def set_output_embeddings(self, new_embeddings):
292
+ self.lm_head = new_embeddings
293
+
294
+ def set_decoder(self, decoder):
295
+ self.model = decoder
296
+
297
+ def get_decoder(self):
298
+ return self.model
299
+
300
+ def generate(self, *args, **kwargs):
301
+ try:
302
+ return super().generate(*args, **kwargs)
303
+ except AttributeError as exception:
304
+ if 'past_key_values' in str(exception):
305
+ raise AttributeError(
306
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
307
+ f"which is not supported for {self.__class__.__name__}. "
308
+ f"Try another generation strategy instead. "
309
+ f"For the available generation strategies, check this doc: "
310
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
311
+ )
312
+ else:
313
+ raise exception
314
+
315
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
316
+ def prepare_inputs_for_generation(
317
+ self,
318
+ input_ids: torch.LongTensor = None,
319
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
320
+ attention_mask: Optional[torch.Tensor] = None,
321
+ inputs_embeds: Optional[torch.Tensor] = None,
322
+ use_cache: bool = True,
323
+ logits_to_keep: Optional[int] = None,
324
+ **kwargs
325
+ ):
326
+ # only the last token of `input_ids` is needed if `past_key_values` is not empty.
327
+ if past_key_values is not None and len(past_key_values) > 0:
328
+ input_ids = input_ids[:, -1:]
329
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
330
+ if inputs_embeds is not None and len(past_key_values) == 0:
331
+ model_inputs = {'inputs_embeds': inputs_embeds}
332
+ else:
333
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
334
+ # recompiles graphs as the stride of the inputs is a guard.
335
+ # Ref: https://github.com/huggingface/transformers/pull/29114
336
+ # TODO: use `next_tokens` directly instead.
337
+ model_inputs = {'input_ids': input_ids.contiguous()}
338
+
339
+ if logits_to_keep is not None:
340
+ model_inputs['logits_to_keep'] = logits_to_keep
341
+
342
+ model_inputs.update({
343
+ 'past_key_values': past_key_values,
344
+ 'use_cache': use_cache,
345
+ 'attention_mask': attention_mask,
346
+ })
347
+ return model_inputs
348
+
349
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
350
+ def forward(
351
+ self,
352
+ input_ids: torch.LongTensor = None,
353
+ attention_mask: Optional[torch.Tensor] = None,
354
+ inputs_embeds: Optional[torch.Tensor] = None,
355
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
356
+ labels: Optional[torch.LongTensor] = None,
357
+ use_cache: Optional[bool] = None,
358
+ output_attentions: Optional[bool] = None,
359
+ output_hidden_states: Optional[bool] = None,
360
+ return_dict: Optional[bool] = None,
361
+ logits_to_keep: Optional[int] = 0,
362
+ **kwargs: Unpack[Dict]
363
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
364
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
365
+ output_hidden_states = (
366
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
367
+ )
368
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
369
+
370
+ outputs = self.model(
371
+ input_ids=input_ids,
372
+ attention_mask=attention_mask,
373
+ inputs_embeds=inputs_embeds,
374
+ past_key_values=past_key_values,
375
+ use_cache=use_cache,
376
+ output_attentions=output_attentions,
377
+ output_hidden_states=output_hidden_states,
378
+ return_dict=return_dict,
379
+ **kwargs
380
+ )
381
+
382
+ hidden_states = outputs[0]
383
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
384
+
385
+ loss, logits = None, None
386
+ if not fuse_linear_and_cross_entropy or labels is None:
387
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
388
+ if labels is not None:
389
+ if getattr(self, 'criterion', None) is None:
390
+ if fuse_linear_and_cross_entropy:
391
+ criterion = FusedLinearCrossEntropyLoss()
392
+ elif self.config.fuse_cross_entropy:
393
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
394
+ else:
395
+ criterion = nn.CrossEntropyLoss()
396
+ else:
397
+ criterion = self.criterion
398
+ labels = labels.to(hidden_states.device)
399
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
400
+ if fuse_linear_and_cross_entropy:
401
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
402
+ else:
403
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
404
+
405
+ if not return_dict:
406
+ output = (logits,) + outputs[1:]
407
+ return (loss,) + output if loss is not None else output
408
+
409
+ return CausalLMOutputWithPast(
410
+ loss=loss,
411
+ logits=logits,
412
+ past_key_values=outputs.past_key_values,
413
+ hidden_states=outputs.hidden_states,
414
+ attentions=outputs.attentions,
415
+ )
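Editor's note (not part of the commit): `DeltaNetForCausalLM.forward` shifts labels internally (dropping the first position and padding with `ignore_index`), so callers pass labels aligned with `input_ids`; when training with `fuse_cross_entropy` enabled, the loss is computed directly from the hidden states via `FusedLinearCrossEntropyLoss` and the full logits are never materialized. A hedged single-step training sketch (assumes a CUDA device with Triton for the DeltaNet kernels; sizes are illustrative):

```python
import torch
from fla.models import DeltaNetConfig, DeltaNetForCausalLM

config = DeltaNetConfig(vocab_size=1000, hidden_size=256, num_hidden_layers=2)
model = DeltaNetForCausalLM(config).cuda()

input_ids = torch.randint(0, 1000, (2, 64), device='cuda')
# labels are passed unshifted: forward() aligns them by dropping the first
# position and appending ignore_index, so labels=input_ids is the usual LM setup
out = model(input_ids=input_ids, labels=input_ids)
out.loss.backward()
```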
fla/models/forgetting_transformer/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (817 Bytes). View file
 
fla/models/forgetting_transformer/__pycache__/modeling_forgetting_transformer.cpython-312.pyc ADDED
Binary file (17.2 kB). View file
 
fla/models/gated_deltanet/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.gated_deltanet.configuration_gated_deltanet import GatedDeltaNetConfig
6
+ from fla.models.gated_deltanet.modeling_gated_deltanet import GatedDeltaNetForCausalLM, GatedDeltaNetModel
7
+
8
+ AutoConfig.register(GatedDeltaNetConfig.model_type, GatedDeltaNetConfig)
9
+ AutoModel.register(GatedDeltaNetConfig, GatedDeltaNetModel)
10
+ AutoModelForCausalLM.register(GatedDeltaNetConfig, GatedDeltaNetForCausalLM)
11
+
12
+ __all__ = ['GatedDeltaNetConfig', 'GatedDeltaNetForCausalLM', 'GatedDeltaNetModel']
fla/models/gated_deltanet/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (746 Bytes). View file
 
fla/models/gated_deltanet/__pycache__/configuration_gated_deltanet.cpython-312.pyc ADDED
Binary file (3.34 kB). View file
 
fla/models/gated_deltanet/__pycache__/modeling_gated_deltanet.cpython-312.pyc ADDED
Binary file (18.5 kB). View file
 
fla/models/gla/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (657 Bytes). View file
 
fla/models/gsa/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.gsa.configuration_gsa import GSAConfig
6
+ from fla.models.gsa.modeling_gsa import GSAForCausalLM, GSAModel
7
+
8
+ AutoConfig.register(GSAConfig.model_type, GSAConfig)
9
+ AutoModel.register(GSAConfig, GSAModel)
10
+ AutoModelForCausalLM.register(GSAConfig, GSAForCausalLM)
11
+
12
+
13
+ __all__ = ['GSAConfig', 'GSAForCausalLM', 'GSAModel']
fla/models/gsa/modeling_gsa.py ADDED
@@ -0,0 +1,420 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.gsa import GatedSlotAttention
20
+ from fla.models.gsa.configuration_gsa import GSAConfig
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as GSAMLP
24
+ from fla.modules import RMSNorm
25
+
26
+ if TYPE_CHECKING:
27
+ from transformers.processing_utils import Unpack
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
+ class GSABlock(nn.Module):
33
+ def __init__(self, config: GSAConfig, layer_idx: int):
34
+ super().__init__()
35
+
36
+ self.config = config
37
+ self.layer_idx = layer_idx
38
+
39
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
40
+ if config.attn is not None and layer_idx in config.attn['layers']:
41
+ self.attn = Attention(
42
+ hidden_size=config.hidden_size,
43
+ num_heads=config.attn['num_heads'],
44
+ num_kv_heads=config.attn['num_kv_heads'],
45
+ qkv_bias=config.attn['qkv_bias'],
46
+ window_size=config.attn['window_size'],
47
+ rope_theta=config.attn['rope_theta'],
48
+ max_position_embeddings=config.max_position_embeddings,
49
+ layer_idx=layer_idx
50
+ )
51
+ else:
52
+ self.attn = GatedSlotAttention(
53
+ hidden_size=config.hidden_size,
54
+ expand_k=config.expand_k,
55
+ expand_v=config.expand_v,
56
+ num_heads=config.num_heads,
57
+ num_kv_heads=config.num_kv_heads,
58
+ num_slots=config.num_slots,
59
+ use_short_conv=config.use_short_conv,
60
+ conv_size=config.conv_size,
61
+ feature_map=config.feature_map,
62
+ use_output_gate=config.use_output_gate,
63
+ use_norm=config.use_norm,
64
+ gate_fn=config.hidden_act,
65
+ gate_logit_normalizer=config.gate_logit_normalizer,
66
+ elementwise_affine=config.elementwise_affine,
67
+ norm_eps=config.norm_eps,
68
+ fuse_norm=config.fuse_norm,
69
+ layer_idx=layer_idx
70
+ )
71
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
72
+ self.mlp = GSAMLP(
73
+ hidden_size=config.hidden_size,
74
+ hidden_ratio=config.hidden_ratio,
75
+ intermediate_size=config.intermediate_size,
76
+ hidden_act=config.hidden_act,
77
+ fuse_swiglu=config.fuse_swiglu
78
+ )
79
+
80
+ def forward(
81
+ self,
82
+ hidden_states: torch.Tensor,
83
+ attention_mask: Optional[torch.Tensor] = None,
84
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
85
+ use_cache: Optional[bool] = False,
86
+ output_attentions: Optional[bool] = False,
87
+ **kwargs: Unpack[Dict]
88
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
89
+ residual = hidden_states
90
+ hidden_states = self.attn_norm(hidden_states)
91
+ hidden_states, attentions, past_key_values = self.attn(
92
+ hidden_states=hidden_states,
93
+ attention_mask=attention_mask,
94
+ past_key_values=past_key_values,
95
+ use_cache=use_cache,
96
+ output_attentions=output_attentions,
97
+ **kwargs
98
+ )
99
+ if self.config.fuse_norm:
100
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
101
+ else:
102
+ hidden_states = residual + hidden_states
103
+ residual = hidden_states
104
+ hidden_states = self.mlp_norm(hidden_states)
105
+ hidden_states = self.mlp(hidden_states, **kwargs)
106
+ hidden_states = residual + hidden_states
107
+
108
+ outputs = (hidden_states, attentions, past_key_values)
109
+
110
+ return outputs
111
+
112
+
113
+ class GSAPreTrainedModel(PreTrainedModel):
114
+
115
+ config_class = GSAConfig
116
+ base_model_prefix = 'model'
117
+ supports_gradient_checkpointing = True
118
+ _no_split_modules = ['GSABlock']
119
+ _supports_cache_class = True
120
+
121
+ def __init__(self, *inputs, **kwargs):
122
+ super().__init__(*inputs, **kwargs)
123
+
124
+ def _init_weights(
125
+ self,
126
+ module: nn.Module,
127
+ prenorm_residual_strategy: Optional[str] = 'rescale',
128
+ num_residuals_per_layer: int = 2,
129
+ ):
130
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
131
+ # Slightly different from the TF version which uses truncated_normal for initialization
132
+ # cf https://github.com/pytorch/pytorch/pull/5617
133
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
134
+ if module.bias is not None:
135
+ nn.init.zeros_(module.bias)
136
+ elif isinstance(module, nn.Embedding):
137
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
138
+ elif hasattr(module, 'reset_parameters'):
139
+ module.reset_parameters()
140
+
141
+ if prenorm_residual_strategy is not None:
142
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
143
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
144
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
145
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
146
+ #
147
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
148
+ p = None
149
+ if hasattr(module, 'o_proj'):
150
+ p = module.o_proj.weight
151
+ elif hasattr(module, 'down_proj'):
152
+ p = module.down_proj.weight
153
+ if p is not None:
154
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
155
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
156
+ # We need to reinit p since this code could be called multiple times
157
+ # Having just p *= scale would repeatedly scale it down
158
+ if prenorm_residual_strategy == 'rescale':
159
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
160
+ with torch.no_grad():
161
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
162
+ elif prenorm_residual_strategy == 'zero':
163
+ nn.init.zeros_(p)
164
+ else:
165
+ raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
166
+
167
+
168
+ class GSAModel(GSAPreTrainedModel):
169
+
170
+ def __init__(self, config: GSAConfig):
171
+ super().__init__(config)
172
+ self.padding_idx = config.pad_token_id
173
+ self.vocab_size = config.vocab_size
174
+
175
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
176
+ self.layers = nn.ModuleList([GSABlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
177
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
178
+
179
+ self.gradient_checkpointing = False
180
+
181
+ self.post_init()
182
+
183
+ def get_input_embeddings(self):
184
+ return self.embeddings
185
+
186
+ def set_input_embeddings(self, value):
187
+ self.embeddings = value
188
+
189
+ def forward(
190
+ self,
191
+ input_ids: Optional[torch.LongTensor] = None,
192
+ attention_mask: Optional[torch.Tensor] = None, # noqa
193
+ inputs_embeds: Optional[torch.FloatTensor] = None,
194
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
195
+ use_cache: Optional[bool] = None,
196
+ output_attentions: Optional[bool] = None,
197
+ output_hidden_states: Optional[bool] = None,
198
+ return_dict: Optional[bool] = None,
199
+ **kwargs: Unpack[Dict]
200
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
201
+ if output_attentions:
202
+ warnings.warn("`GSAModel` does not support `output_attentions` for now, setting it to `False`.")
203
+ output_attentions = False
204
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
205
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
206
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
207
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
208
+
209
+ # retrieve input_ids and inputs_embeds
210
+ if input_ids is not None and inputs_embeds is not None:
211
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
212
+ if input_ids is None and inputs_embeds is None:
213
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
214
+
215
+ if inputs_embeds is None:
216
+ inputs_embeds = self.embeddings(input_ids)
217
+ hidden_states = inputs_embeds
218
+
219
+ if use_cache and not isinstance(past_key_values, Cache):
220
+ past_key_values = Cache.from_legacy_cache(past_key_values)
221
+
222
+ if self.gradient_checkpointing and self.training and use_cache:
223
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
224
+ use_cache = False
225
+
226
+ all_hidden_states = () if output_hidden_states else None
227
+ all_attns = () if output_attentions else None
228
+ for layer in self.layers:
229
+ if output_hidden_states:
230
+ all_hidden_states += (hidden_states,)
231
+
232
+ if self.gradient_checkpointing and self.training:
233
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
234
+ layer.__call__,
235
+ hidden_states,
236
+ attention_mask,
237
+ past_key_values,
238
+ use_cache,
239
+ output_attentions,
240
+ **kwargs
241
+ )
242
+ else:
243
+ hidden_states, attentions, past_key_values = layer(
244
+ hidden_states,
245
+ attention_mask=attention_mask,
246
+ past_key_values=past_key_values,
247
+ use_cache=use_cache,
248
+ output_attentions=output_attentions,
249
+ **kwargs
250
+ )
251
+
252
+ if output_attentions:
253
+ all_attns += (attentions,)
254
+
255
+ hidden_states = self.norm(hidden_states)
256
+
257
+ # add hidden states from the last decoder layer
258
+ if output_hidden_states:
259
+ all_hidden_states += (hidden_states,)
260
+
261
+ if not return_dict:
262
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
263
+ return BaseModelOutputWithPast(
264
+ last_hidden_state=hidden_states,
265
+ past_key_values=past_key_values,
266
+ hidden_states=all_hidden_states,
267
+ attentions=all_attns
268
+ )
269
+
270
+
271
+ class GSAForCausalLM(GSAPreTrainedModel, GenerationMixin):
272
+
273
+ _tied_weights_keys = ["lm_head.weight"]
274
+
275
+ def __init__(self, config):
276
+
277
+ super().__init__(config)
278
+ self.model = GSAModel(config)
279
+ self.vocab_size = config.vocab_size
280
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
281
+ self.criterion = None
282
+
283
+ # Initialize weights and apply final processing
284
+ self.post_init()
285
+
286
+ def get_input_embeddings(self):
287
+ return self.model.embeddings
288
+
289
+ def set_input_embeddings(self, value):
290
+ self.model.embeddings = value
291
+
292
+ def get_output_embeddings(self):
293
+ return self.lm_head
294
+
295
+ def set_output_embeddings(self, new_embeddings):
296
+ self.lm_head = new_embeddings
297
+
298
+ def set_decoder(self, decoder):
299
+ self.model = decoder
300
+
301
+ def get_decoder(self):
302
+ return self.model
303
+
304
+ def generate(self, *args, **kwargs):
305
+ try:
306
+ return super().generate(*args, **kwargs)
307
+ except AttributeError as exception:
308
+ if 'past_key_values' in str(exception):
309
+ raise AttributeError(
310
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
311
+ f"which is not supported for {self.__class__.__name__}. "
312
+ f"Try another generation strategy instead. "
313
+ f"For the available generation strategies, check this doc: "
314
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
315
+ )
316
+ else:
317
+ raise exception
318
+
319
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
320
+ def prepare_inputs_for_generation(
321
+ self,
322
+ input_ids: torch.LongTensor = None,
323
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
324
+ attention_mask: Optional[torch.Tensor] = None,
325
+ inputs_embeds: Optional[torch.Tensor] = None,
326
+ use_cache: bool = True,
327
+ logits_to_keep: Optional[int] = None,
328
+ **kwargs
329
+ ):
330
+ # only the last token of `input_ids` is needed if `past_key_values` is not empty.
331
+ if past_key_values is not None and len(past_key_values) > 0:
332
+ input_ids = input_ids[:, -1:]
333
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
334
+ if inputs_embeds is not None and len(past_key_values) == 0:
335
+ model_inputs = {'inputs_embeds': inputs_embeds}
336
+ else:
337
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
338
+ # recompiles graphs as the stride of the inputs is a guard.
339
+ # Ref: https://github.com/huggingface/transformers/pull/29114
340
+ # TODO: use `next_tokens` directly instead.
341
+ model_inputs = {'input_ids': input_ids.contiguous()}
342
+
343
+ if logits_to_keep is not None:
344
+ model_inputs['logits_to_keep'] = logits_to_keep
345
+
346
+ model_inputs.update({
347
+ 'past_key_values': past_key_values,
348
+ 'use_cache': use_cache,
349
+ 'attention_mask': attention_mask,
350
+ })
351
+ return model_inputs
352
+
353
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
354
+ def forward(
355
+ self,
356
+ input_ids: torch.LongTensor = None,
357
+ attention_mask: Optional[torch.Tensor] = None,
358
+ inputs_embeds: Optional[torch.Tensor] = None,
359
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
360
+ labels: Optional[torch.LongTensor] = None,
361
+ use_cache: Optional[bool] = None,
362
+ output_attentions: Optional[bool] = None,
363
+ output_hidden_states: Optional[bool] = None,
364
+ return_dict: Optional[bool] = None,
365
+ logits_to_keep: Optional[int] = 0,
366
+ **kwargs: Unpack[Dict]
367
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
368
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
369
+ output_hidden_states = (
370
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
371
+ )
372
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
373
+
374
+ outputs = self.model(
375
+ input_ids=input_ids,
376
+ attention_mask=attention_mask,
377
+ inputs_embeds=inputs_embeds,
378
+ past_key_values=past_key_values,
379
+ use_cache=use_cache,
380
+ output_attentions=output_attentions,
381
+ output_hidden_states=output_hidden_states,
382
+ return_dict=return_dict,
383
+ **kwargs
384
+ )
385
+
386
+ hidden_states = outputs[0]
387
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
388
+
389
+ loss, logits = None, None
390
+ if not fuse_linear_and_cross_entropy or labels is None:
391
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
392
+ if labels is not None:
393
+ if getattr(self, 'criterion', None) is None:
394
+ if fuse_linear_and_cross_entropy:
395
+ criterion = FusedLinearCrossEntropyLoss()
396
+ elif self.config.fuse_cross_entropy:
397
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
398
+ else:
399
+ criterion = nn.CrossEntropyLoss()
400
+ else:
401
+ criterion = self.criterion
402
+ # Enable model parallelism
403
+ labels = labels.to(hidden_states.device)
404
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
405
+ if fuse_linear_and_cross_entropy:
406
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
407
+ else:
408
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
409
+
410
+ if not return_dict:
411
+ output = (logits,) + outputs[1:]
412
+ return (loss,) + output if loss is not None else output
413
+
414
+ return CausalLMOutputWithPast(
415
+ loss=loss,
416
+ logits=logits,
417
+ past_key_values=outputs.past_key_values,
418
+ hidden_states=outputs.hidden_states,
419
+ attentions=outputs.attentions,
420
+ )
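Editor's note (not part of the commit): during decoding, `prepare_inputs_for_generation` feeds only the last token once `past_key_values` is non-empty, since the recurrent and short-convolution states held in the `Cache` summarize the prefix; `generate` also rewraps failures from decoding strategies that manipulate `past_key_values` into a clearer error. A generation sketch under those assumptions; the checkpoint id is hypothetical and shown only to illustrate the call pattern:

```python
import fla.models  # noqa: F401  # triggers AutoConfig/AutoModel registration for the GSA model type
from transformers import AutoModelForCausalLM, AutoTokenizer

name = 'fla-hub/gsa-1.3B-100B'  # hypothetical repo id
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name).cuda()

inputs = tokenizer('Gated Slot Attention is', return_tensors='pt').to('cuda')
# after prefill, only the last generated token is fed back per step;
# the per-layer states travel in past_key_values
out = model.generate(**inputs, max_new_tokens=32, do_sample=False)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```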
fla/models/hgrn/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.hgrn.configuration_hgrn import HGRNConfig
6
+ from fla.models.hgrn.modeling_hgrn import HGRNForCausalLM, HGRNModel
7
+
8
+ AutoConfig.register(HGRNConfig.model_type, HGRNConfig)
9
+ AutoModel.register(HGRNConfig, HGRNModel)
10
+ AutoModelForCausalLM.register(HGRNConfig, HGRNForCausalLM)
11
+
12
+
13
+ __all__ = ['HGRNConfig', 'HGRNForCausalLM', 'HGRNModel']
fla/models/hgrn2/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (674 Bytes). View file
 
fla/models/lightnet/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (699 Bytes). View file
 
fla/models/lightnet/__pycache__/modeling_lightnet.cpython-312.pyc ADDED
Binary file (18.3 kB). View file
 
fla/models/lightnet/modeling_lightnet.py ADDED
@@ -0,0 +1,410 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.lightnet import LightNetAttention
20
+ from fla.models.lightnet.configuration_lightnet import LightNetConfig
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as LightNetMLP
24
+ from fla.modules import RMSNorm
25
+
26
+ if TYPE_CHECKING:
27
+ from transformers.processing_utils import Unpack
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
+ class LightNetBlock(nn.Module):
33
+ def __init__(self, config: LightNetConfig, layer_idx: int):
34
+ super().__init__()
35
+
36
+ self.config = config
37
+ self.layer_idx = layer_idx
38
+
39
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
40
+ if config.attn is not None and layer_idx in config.attn['layers']:
41
+ self.attn = Attention(
42
+ hidden_size=config.hidden_size,
43
+ num_heads=config.attn['num_heads'],
44
+ num_kv_heads=config.attn['num_kv_heads'],
45
+ qkv_bias=config.attn['qkv_bias'],
46
+ window_size=config.attn['window_size'],
47
+ max_position_embeddings=config.max_position_embeddings,
48
+ layer_idx=layer_idx
49
+ )
50
+ else:
51
+ self.attn = LightNetAttention(
52
+ mode=config.attn_mode,
53
+ hidden_size=config.hidden_size,
54
+ num_heads=config.num_heads,
55
+ expand_ratio=config.expand_ratio,
56
+ use_short_conv=config.use_short_conv,
57
+ conv_size=config.conv_size,
58
+ gate_low_rank_dim=config.gate_low_rank_dim,
59
+ elementwise_affine=config.elementwise_affine,
60
+ norm_eps=config.norm_eps,
61
+ layer_idx=layer_idx
62
+ )
63
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
64
+ self.mlp = LightNetMLP(
65
+ hidden_size=config.hidden_size,
66
+ hidden_ratio=config.hidden_ratio,
67
+ intermediate_size=config.intermediate_size,
68
+ hidden_act=config.hidden_act,
69
+ fuse_swiglu=config.fuse_swiglu
70
+ )
71
+
72
+ def forward(
73
+ self,
74
+ hidden_states: torch.Tensor,
75
+ attention_mask: Optional[torch.Tensor] = None,
76
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
77
+ use_cache: Optional[bool] = False,
78
+ output_attentions: Optional[bool] = False,
79
+ **kwargs: Unpack[Dict]
80
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
81
+ residual = hidden_states
82
+ hidden_states = self.attn_norm(hidden_states)
83
+ hidden_states, attentions, past_key_values = self.attn(
84
+ hidden_states=hidden_states,
85
+ attention_mask=attention_mask,
86
+ past_key_values=past_key_values,
87
+ use_cache=use_cache,
88
+ output_attentions=output_attentions,
89
+ **kwargs
90
+ )
91
+ if self.config.fuse_norm:
92
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
93
+ else:
94
+ hidden_states = residual + hidden_states
95
+ residual = hidden_states
96
+ hidden_states = self.mlp_norm(hidden_states)
97
+ hidden_states = self.mlp(hidden_states, **kwargs)
98
+ hidden_states = residual + hidden_states
99
+
100
+ outputs = (hidden_states, attentions, past_key_values)
101
+
102
+ return outputs
103
+
104
+
105
+ class LightNetPreTrainedModel(PreTrainedModel):
106
+
107
+ config_class = LightNetConfig
108
+ supports_gradient_checkpointing = True
109
+ _no_split_modules = ['LightNetBlock']
110
+ _supports_cache_class = True
111
+
112
+ def __init__(self, *inputs, **kwargs):
113
+ super().__init__(*inputs, **kwargs)
114
+
115
+ def _init_weights(
116
+ self,
117
+ module: nn.Module,
118
+ prenorm_residual_strategy: Optional[str] = 'rescale',
119
+ num_residuals_per_layer: int = 2,
120
+ ):
121
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
122
+ # Slightly different from the TF version which uses truncated_normal for initialization
123
+ # cf https://github.com/pytorch/pytorch/pull/5617
124
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
125
+ if module.bias is not None:
126
+ nn.init.zeros_(module.bias)
127
+ elif isinstance(module, nn.Embedding):
128
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
129
+ elif hasattr(module, 'reset_parameters'):
130
+ module.reset_parameters()
131
+
132
+ if prenorm_residual_strategy is not None:
133
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
134
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
135
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
136
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
137
+ #
138
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
139
+ p = None
140
+ if hasattr(module, 'o_proj'):
141
+ p = module.o_proj.weight
142
+ elif hasattr(module, 'down_proj'):
143
+ p = module.down_proj.weight
144
+ if p is not None:
145
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
146
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
147
+ # We need to reinit p since this code could be called multiple times
148
+ # Having just p *= scale would repeatedly scale it down
149
+ if prenorm_residual_strategy == 'rescale':
150
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
151
+ with torch.no_grad():
152
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
153
+ elif prenorm_residual_strategy == 'zero':
154
+ nn.init.zeros_(p)
155
+ else:
156
+ raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
157
+
158
+
159
+ class LightNetModel(LightNetPreTrainedModel):
160
+
161
+ def __init__(self, config: LightNetConfig):
162
+ super().__init__(config)
163
+ self.padding_idx = config.pad_token_id
164
+ self.vocab_size = config.vocab_size
165
+
166
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
167
+ self.layers = nn.ModuleList([LightNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
168
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
169
+
170
+ self.gradient_checkpointing = False
171
+
172
+ self.post_init()
173
+
174
+ def get_input_embeddings(self):
175
+ return self.embeddings
176
+
177
+ def set_input_embeddings(self, value):
178
+ self.embeddings = value
179
+
180
+ def forward(
181
+ self,
182
+ input_ids: Optional[torch.LongTensor] = None,
183
+ attention_mask: Optional[torch.Tensor] = None, # noqa
184
+ inputs_embeds: Optional[torch.FloatTensor] = None,
185
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
186
+ use_cache: Optional[bool] = None,
187
+ output_attentions: Optional[bool] = None,
188
+ output_hidden_states: Optional[bool] = None,
189
+ return_dict: Optional[bool] = None,
190
+ **kwargs: Unpack[Dict]
191
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
192
+ if output_attentions:
193
+ warnings.warn("`LightNetModel` does not support `output_attentions` for now, setting it to `False`.")
194
+ output_attentions = False
195
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
196
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
197
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
198
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
199
+
200
+ # retrieve input_ids and inputs_embeds
201
+ if input_ids is not None and inputs_embeds is not None:
202
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
203
+ if input_ids is None and inputs_embeds is None:
204
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
205
+
206
+ if inputs_embeds is None:
207
+ inputs_embeds = self.embeddings(input_ids)
208
+ hidden_states = inputs_embeds
209
+
210
+ if use_cache and not isinstance(past_key_values, Cache):
211
+ past_key_values = Cache.from_legacy_cache(past_key_values)
212
+
213
+ if self.gradient_checkpointing and self.training and use_cache:
214
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
215
+ use_cache = False
216
+
217
+ all_hidden_states = () if output_hidden_states else None
218
+ all_attns = () if output_attentions else None
219
+
220
+ for i, layer in enumerate(self.layers):
221
+ if output_hidden_states:
222
+ all_hidden_states += (hidden_states,)
223
+
224
+ if self.gradient_checkpointing and self.training:
225
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
226
+ layer.__call__,
227
+ hidden_states,
228
+ attention_mask,
229
+ past_key_values,
230
+ use_cache,
231
+ output_attentions,
232
+ **kwargs
233
+ )
234
+ else:
235
+ hidden_states, attentions, past_key_values = layer(
236
+ hidden_states,
237
+ attention_mask=attention_mask,
238
+ past_key_values=past_key_values,
239
+ use_cache=use_cache,
240
+ output_attentions=output_attentions,
241
+ **kwargs
242
+ )
243
+
244
+ if output_attentions:
245
+ all_attns += (attentions,)
246
+
247
+ hidden_states = self.norm(hidden_states)
248
+
249
+ # add hidden states from the last decoder layer
250
+ if output_hidden_states:
251
+ all_hidden_states += (hidden_states,)
252
+
253
+ if not return_dict:
254
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
255
+ return BaseModelOutputWithPast(
256
+ last_hidden_state=hidden_states,
257
+ past_key_values=past_key_values,
258
+ hidden_states=all_hidden_states,
259
+ attentions=all_attns
260
+ )
261
+
262
+
263
+ class LightNetForCausalLM(LightNetPreTrainedModel, GenerationMixin):
264
+
265
+ _tied_weights_keys = ["lm_head.weight"]
266
+
267
+ def __init__(self, config):
268
+ super().__init__(config)
269
+ self.model = LightNetModel(config)
270
+ self.vocab_size = config.vocab_size
271
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
272
+ self.criterion = None
273
+
274
+ # Initialize weights and apply final processing
275
+ self.post_init()
276
+
277
+ def get_input_embeddings(self):
278
+ return self.model.embeddings
279
+
280
+ def set_input_embeddings(self, value):
281
+ self.model.embeddings = value
282
+
283
+ def get_output_embeddings(self):
284
+ return self.lm_head
285
+
286
+ def set_output_embeddings(self, new_embeddings):
287
+ self.lm_head = new_embeddings
288
+
289
+ def set_decoder(self, decoder):
290
+ self.model = decoder
291
+
292
+ def get_decoder(self):
293
+ return self.model
294
+
295
+ def generate(self, *args, **kwargs):
296
+ try:
297
+ return super().generate(*args, **kwargs)
298
+ except AttributeError as exception:
299
+ if 'past_key_values' in str(exception):
300
+ raise AttributeError(
301
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
302
+ f"which is not supported for {self.__class__.__name__}. "
303
+ f"Try another generation strategy instead. "
304
+ f"For the available generation strategies, check this doc: "
305
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
306
+ )
307
+ else:
308
+ raise exception
309
+
310
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
311
+ def prepare_inputs_for_generation(
312
+ self,
313
+ input_ids: torch.LongTensor = None,
314
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
315
+ attention_mask: Optional[torch.Tensor] = None,
316
+ inputs_embeds: Optional[torch.Tensor] = None,
317
+ use_cache: bool = True,
318
+ logits_to_keep: Optional[int] = None,
319
+ **kwargs: Unpack[Dict]
320
+ ):
321
+ # only last token for `input_ids` if the `past_key_values` is not empty.
322
+ if past_key_values is not None and len(past_key_values) > 0:
323
+ input_ids = input_ids[:, -1:]
324
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
325
+ if inputs_embeds is not None and len(past_key_values) == 0:
326
+ model_inputs = {'inputs_embeds': inputs_embeds}
327
+ else:
328
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
329
+ # recompiles graphs as the stride of the inputs is a guard.
330
+ # Ref: https://github.com/huggingface/transformers/pull/29114
331
+ # TODO: use `next_tokens` directly instead.
332
+ model_inputs = {'input_ids': input_ids.contiguous()}
333
+
334
+ if logits_to_keep is not None:
335
+ model_inputs['logits_to_keep'] = logits_to_keep
336
+
337
+ model_inputs.update({
338
+ 'past_key_values': past_key_values,
339
+ 'use_cache': use_cache,
340
+ 'attention_mask': attention_mask,
341
+ })
342
+ return model_inputs
343
+
344
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
345
+ def forward(
346
+ self,
347
+ input_ids: torch.LongTensor = None,
348
+ attention_mask: Optional[torch.Tensor] = None,
349
+ inputs_embeds: Optional[torch.Tensor] = None,
350
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
351
+ labels: Optional[torch.LongTensor] = None,
352
+ use_cache: Optional[bool] = None,
353
+ output_attentions: Optional[bool] = None,
354
+ output_hidden_states: Optional[bool] = None,
355
+ return_dict: Optional[bool] = None,
356
+ logits_to_keep: Optional[int] = 0,
357
+ **kwargs: Unpack[Dict]
358
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
359
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
360
+ output_hidden_states = (
361
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
362
+ )
363
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
364
+
365
+ outputs = self.model(
366
+ input_ids=input_ids,
367
+ attention_mask=attention_mask,
368
+ inputs_embeds=inputs_embeds,
369
+ past_key_values=past_key_values,
370
+ use_cache=use_cache,
371
+ output_attentions=output_attentions,
372
+ output_hidden_states=output_hidden_states,
373
+ return_dict=return_dict,
374
+ **kwargs
375
+ )
376
+
377
+ hidden_states = outputs[0]
378
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
379
+
380
+ loss, logits = None, None
381
+ if not fuse_linear_and_cross_entropy or labels is None:
382
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
383
+ if labels is not None:
384
+ if getattr(self, 'criterion', None) is None:
385
+ if fuse_linear_and_cross_entropy:
386
+ criterion = FusedLinearCrossEntropyLoss()
387
+ elif self.config.fuse_cross_entropy:
388
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
389
+ else:
390
+ criterion = nn.CrossEntropyLoss()
391
+ else:
392
+ criterion = self.criterion
393
+ labels = labels.to(hidden_states.device)
394
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
395
+ if fuse_linear_and_cross_entropy:
396
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
397
+ else:
398
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
399
+
400
+ if not return_dict:
401
+ output = (logits,) + outputs[1:]
402
+ return (loss,) + output if loss is not None else output
403
+
404
+ return CausalLMOutputWithPast(
405
+ loss=loss,
406
+ logits=logits,
407
+ past_key_values=outputs.past_key_values,
408
+ hidden_states=outputs.hidden_states,
409
+ attentions=outputs.attentions,
410
+ )
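For reference, the label handling in `LightNetForCausalLM.forward` above shifts the targets one position to the left and pads the final slot with the criterion's `ignore_index`, so that token *t* predicts token *t+1*. A minimal sketch of that shift in plain PyTorch (the `-100` ignore index is an assumption matching the `nn.CrossEntropyLoss` default):

```python
import torch

# toy labels for a batch with one sequence
labels = torch.tensor([[10, 11, 12, 13]])
ignore_index = -100  # assumed; nn.CrossEntropyLoss uses -100 by default

# shift left by one and pad the last position with ignore_index,
# mirroring the torch.cat(...) call in the forward pass above
shifted = torch.cat(
    (labels[..., 1:], torch.full_like(labels[:, :1], ignore_index)), 1
)
print(shifted)  # tensor([[ 11,  12,  13, -100]])
```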
fla/models/linear_attn/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (737 Bytes). View file
 
fla/models/linear_attn/__pycache__/configuration_linear_attn.cpython-312.pyc ADDED
Binary file (3.65 kB). View file
 
fla/models/mamba/configuration_mamba.py ADDED
@@ -0,0 +1,166 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """MAMBA configuration"""
16
+
17
+ import math
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+
21
+
22
+ class MambaConfig(PretrainedConfig):
23
+ """
24
+ This is the configuration class to store the configuration of a [`MambaModel`]. It is used to instantiate a MAMBA
25
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
26
+ defaults will yield a similar configuration to that of the MAMBA
27
+ [state-spaces/mamba-2.8b](https://huggingface.co/state-spaces/mamba-2.8b) architecture.
28
+
29
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
30
+ documentation from [`PretrainedConfig`] for more information.
31
+
32
+
33
+ Args:
34
+ vocab_size (`int`, *optional*):
35
+ Vocabulary size of the Mamba model.
36
+ hidden_size (`int`, *optional*):
37
+ Dimensionality of the embeddings and hidden states. Default: 2048.
38
+ state_size (`int`, *optional*):
39
+ Shape of the state space latents. Default: 16.
40
+ num_hidden_layers (`int`, *optional*):
41
+ Number of hidden layers in the model. Default: 48.
42
+ layer_norm_epsilon (`float`, *optional*):
43
+ The epsilon to use in the layer normalization layers. Default: 1e-5.
44
+ pad_token_id (`int`, *optional*):
45
+ Padding token id. Default: 0.
46
+ bos_token_id (`int`, *optional*):
47
+ The id of the beginning of sentence token in the vocabulary. Default: 1.
48
+ eos_token_id (`int`, *optional*):
49
+ The id of the end of sentence token in the vocabulary. Default: 2.
50
+ expand (`int`, *optional*):
51
+ Expanding factor used to determine the intermediate size. Default: 2.
52
+ conv_kernel (`int`, *optional*):
53
+ Size of the convolution kernel. Default: 4.
54
+ use_bias (`bool`, *optional*):
55
+ Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block. Default: `False`.
56
+ use_conv_bias (`bool`, *optional*):
57
+ Whether or not to use bias in the convolution layer of the mixer block. Default: `True`.
58
+ hidden_act (`str`, *optional*):
59
+ The non-linear activation function (function or string) in the decoder. Default: `"silu"`.
60
+ initializer_range (`float`, *optional*):
61
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices. Default: 0.1.
62
+ residual_in_fp32 (`bool`, *optional*):
63
+ Whether or not residuals should be in `float32`.
64
+ If set to `False`, residuals will keep the same `dtype` as the rest of the model. Default: `False`.
65
+ time_step_rank (`Union[int,str]`, *optional*):
66
+ Rank of the discretization projection matrix.
67
+ `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`. Default: `"auto"`.
68
+ time_step_scale (`float`, *optional*):
69
+ Scale used to scale `dt_proj.bias`. Default: 1.0.
70
+ time_step_min (`float`, *optional*):
71
+ Minimum `time_step` used to bound `dt_proj.bias`. Default: 0.001.
72
+ time_step_max (`float`, *optional*):
73
+ Maximum `time_step` used to bound `dt_proj.bias`. Default: 0.1.
74
+ time_step_init_scheme (`float`, *optional*):
75
+ Init scheme used for `dt_proj.weight`. Should be one of `["random","uniform"]`. Default: `"random"`.
76
+ time_step_floor (`float`, *optional*):
77
+ Minimum clamping value of the `dt_proj.bias` layer initialization. Default: 0.0001.
78
+ window_size (`int`, *optional*):
79
+ The window size used for sliding window attention. Default: 2048.
80
+ rescale_prenorm_residual (`bool`, *optional*):
81
+ Whether or not to rescale `out_proj` weights when initializing. Default: `False`.
82
+ use_cache (`bool`, *optional*):
83
+ Whether or not the cache should be used. Default: `True`.
84
+
85
+
86
+ Example:
87
+
88
+ ```python
89
+ >>> from transformers import MambaConfig, MambaModel
90
+
91
+ >>> # Initializing a Mamba configuration
92
+ >>> configuration = MambaConfig()
93
+
94
+ >>> # Initializing a model (with random weights) from the configuration
95
+ >>> model = MambaModel(configuration)
96
+
97
+ >>> # Accessing the model configuration
98
+ >>> configuration = model.config
99
+ ```"""
100
+
101
+ model_type = "mamba"
102
+
103
+ def __init__(
104
+ self,
105
+ vocab_size: int = 32000,
106
+ hidden_size: int = 2048,
107
+ state_size: int = 16,
108
+ num_hidden_layers: int = 48,
109
+ layer_norm_epsilon=1e-5,
110
+ pad_token_id: int = 0,
111
+ bos_token_id: int = 1,
112
+ eos_token_id: int = 2,
113
+ expand: int = 2,
114
+ conv_kernel: int = 4,
115
+ use_bias: bool = False,
116
+ use_conv_bias: bool = True,
117
+ hidden_act: str = "silu",
118
+ initializer_range: float = 0.1,
119
+ residual_in_fp32: bool = False,
120
+ time_step_rank: str = "auto",
121
+ time_step_scale: float = 1.0,
122
+ time_step_min: float = 0.001,
123
+ time_step_max: float = 0.1,
124
+ time_step_init_scheme: str = "random",
125
+ time_step_floor: float = 1e-4,
126
+ rescale_prenorm_residual: bool = False,
127
+ use_cache: bool = True,
128
+ fuse_norm: bool = True,
129
+ fuse_cross_entropy: bool = True,
130
+ tie_word_embeddings: bool = False,
131
+ **kwargs,
132
+ ):
133
+ self.vocab_size = vocab_size
134
+ self.hidden_size = hidden_size
135
+ self.state_size = state_size
136
+ self.num_hidden_layers = num_hidden_layers
137
+ self.layer_norm_epsilon = layer_norm_epsilon
138
+ self.conv_kernel = conv_kernel
139
+ self.expand = expand
140
+ self.intermediate_size = int(expand * self.hidden_size)
141
+ self.bos_token_id = bos_token_id
142
+ self.eos_token_id = eos_token_id
143
+ self.pad_token_id = pad_token_id
144
+ self.use_bias = use_bias
145
+ self.use_conv_bias = use_conv_bias
146
+ self.hidden_act = hidden_act
147
+ self.initializer_range = initializer_range
148
+ self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank
149
+ self.time_step_scale = time_step_scale
150
+ self.time_step_min = time_step_min
151
+ self.time_step_max = time_step_max
152
+ self.time_step_init_scheme = time_step_init_scheme
153
+ self.time_step_floor = time_step_floor
154
+ self.rescale_prenorm_residual = rescale_prenorm_residual
155
+ self.residual_in_fp32 = residual_in_fp32
156
+ self.use_cache = use_cache
157
+ self.fuse_norm = fuse_norm
158
+ self.fuse_cross_entropy = fuse_cross_entropy
159
+
160
+ super().__init__(
161
+ bos_token_id=bos_token_id,
162
+ eos_token_id=eos_token_id,
163
+ pad_token_id=pad_token_id,
164
+ tie_word_embeddings=tie_word_embeddings,
165
+ **kwargs
166
+ )
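As a quick check of the derived fields documented above, `time_step_rank="auto"` resolves to `math.ceil(hidden_size / 16)` and `intermediate_size` is computed as `expand * hidden_size`. A small sketch using the defaults shown in this file:

```python
import math

hidden_size, expand = 2048, 2
time_step_rank = math.ceil(hidden_size / 16)   # "auto" -> 128
intermediate_size = int(expand * hidden_size)  # 2 * 2048 = 4096
print(time_step_rank, intermediate_size)       # 128 4096
```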
fla/models/nsa/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.nsa.configuration_nsa import NSAConfig
6
+ from fla.models.nsa.modeling_nsa import NSAForCausalLM, NSAModel
7
+
8
+ AutoConfig.register(NSAConfig.model_type, NSAConfig)
9
+ AutoModel.register(NSAConfig, NSAModel)
10
+ AutoModelForCausalLM.register(NSAConfig, NSAForCausalLM)
11
+
12
+
13
+ __all__ = [
14
+ 'NSAConfig', 'NSAModel', 'NSAForCausalLM',
15
+ ]
fla/models/nsa/modeling_nsa.py ADDED
@@ -0,0 +1,398 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.nsa import NativeSparseAttention
19
+ from fla.models.nsa.configuration_nsa import NSAConfig
20
+ from fla.models.utils import Cache
21
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
22
+ from fla.modules import GatedMLP as NSAMLP
23
+ from fla.modules import RMSNorm
24
+
25
+ if TYPE_CHECKING:
26
+ from transformers.processing_utils import Unpack
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ class NSABlock(nn.Module):
32
+ def __init__(self, config: NSAConfig, layer_idx: int):
33
+ super().__init__()
34
+
35
+ self.config = config
36
+ self.layer_idx = layer_idx
37
+
38
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
39
+ self.attn = NativeSparseAttention(
40
+ hidden_size=config.hidden_size,
41
+ num_heads=config.num_heads,
42
+ num_kv_heads=config.num_kv_heads,
43
+ qkv_bias=config.qkv_bias,
44
+ block_size=config.block_size,
45
+ block_counts=config.block_counts,
46
+ window_size=config.window_size,
47
+ rope_theta=config.rope_theta,
48
+ max_position_embeddings=config.max_position_embeddings,
49
+ layer_idx=layer_idx
50
+ )
51
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
52
+ self.mlp = NSAMLP(
53
+ hidden_size=config.hidden_size,
54
+ hidden_ratio=config.hidden_ratio,
55
+ intermediate_size=config.intermediate_size,
56
+ hidden_act=config.hidden_act,
57
+ fuse_swiglu=config.fuse_swiglu
58
+ )
59
+
60
+ def forward(
61
+ self,
62
+ hidden_states: torch.Tensor,
63
+ attention_mask: Optional[torch.Tensor] = None,
64
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
65
+ use_cache: Optional[bool] = False,
66
+ output_attentions: Optional[bool] = False,
67
+ **kwargs: Unpack[Dict]
68
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
69
+ residual = hidden_states
70
+ hidden_states = self.attn_norm(hidden_states)
71
+ hidden_states, attentions, past_key_values = self.attn(
72
+ hidden_states=hidden_states,
73
+ attention_mask=attention_mask,
74
+ past_key_values=past_key_values,
75
+ use_cache=use_cache,
76
+ output_attentions=output_attentions,
77
+ **kwargs
78
+ )
79
+ if self.config.fuse_norm:
80
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
81
+ else:
82
+ hidden_states = residual + hidden_states
83
+ residual = hidden_states
84
+ hidden_states = self.mlp_norm(hidden_states)
85
+ hidden_states = self.mlp(hidden_states, **kwargs)
86
+ hidden_states = residual + hidden_states
87
+
88
+ outputs = (hidden_states, attentions, past_key_values)
89
+
90
+ return outputs
91
+
92
+
93
+ class NSAPreTrainedModel(PreTrainedModel):
94
+
95
+ config_class = NSAConfig
96
+ base_model_prefix = 'model'
97
+ supports_gradient_checkpointing = True
98
+ _no_split_modules = ['NSABlock']
99
+ _supports_cache_class = True
100
+
101
+ def __init__(self, *inputs, **kwargs):
102
+ super().__init__(*inputs, **kwargs)
103
+
104
+ def _init_weights(
105
+ self,
106
+ module: nn.Module,
107
+ prenorm_residual_strategy: Optional[str] = 'rescale',
108
+ num_residuals_per_layer: int = 2,
109
+ ):
110
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
111
+ # Slightly different from the TF version which uses truncated_normal for initialization
112
+ # cf https://github.com/pytorch/pytorch/pull/5617
113
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
114
+ if module.bias is not None:
115
+ nn.init.zeros_(module.bias)
116
+ elif isinstance(module, nn.Embedding):
117
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
118
+ elif hasattr(module, 'reset_parameters'):
119
+ module.reset_parameters()
120
+
121
+ if prenorm_residual_strategy is not None:
122
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
123
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
124
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
125
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
126
+ #
127
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
128
+ p = None
129
+ if hasattr(module, 'o_proj'):
130
+ p = module.o_proj.weight
131
+ elif hasattr(module, 'down_proj'):
132
+ p = module.down_proj.weight
133
+ if p is not None:
134
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
135
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
136
+ # We need to reinit p since this code could be called multiple times
137
+ # Having just p *= scale would repeatedly scale it down
138
+ if prenorm_residual_strategy == 'rescale':
139
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
140
+ with torch.no_grad():
141
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
142
+ elif prenorm_residual_strategy == 'zero':
143
+ nn.init.zeros_(p)
144
+ else:
145
+ raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
146
+
147
+
148
+ class NSAModel(NSAPreTrainedModel):
149
+
150
+ def __init__(self, config: NSAConfig):
151
+ super().__init__(config)
152
+ self.padding_idx = config.pad_token_id
153
+ self.vocab_size = config.vocab_size
154
+
155
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
156
+ self.layers = nn.ModuleList([NSABlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
157
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
158
+
159
+ self.gradient_checkpointing = False
160
+
161
+ self.post_init()
162
+
163
+ def get_input_embeddings(self):
164
+ return self.embeddings
165
+
166
+ def set_input_embeddings(self, value):
167
+ self.embeddings = value
168
+
169
+ def forward(
170
+ self,
171
+ input_ids: Optional[torch.LongTensor] = None,
172
+ attention_mask: Optional[torch.Tensor] = None, # noqa
173
+ inputs_embeds: Optional[torch.FloatTensor] = None,
174
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
175
+ use_cache: Optional[bool] = None,
176
+ output_attentions: Optional[bool] = None,
177
+ output_hidden_states: Optional[bool] = None,
178
+ return_dict: Optional[bool] = None,
179
+ **kwargs: Unpack[Dict]
180
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
181
+ if output_attentions:
182
+ warnings.warn("`NSAModel` does not `output_attentions` now, setting it to `False`.")
183
+ output_attentions = False
184
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
185
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
186
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
187
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
188
+
189
+ # retrieve input_ids and inputs_embeds
190
+ if input_ids is not None and inputs_embeds is not None:
191
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
192
+ if input_ids is None and inputs_embeds is None:
193
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
194
+
195
+ if inputs_embeds is None:
196
+ inputs_embeds = self.embeddings(input_ids)
197
+ hidden_states = inputs_embeds
198
+
199
+ if use_cache and not isinstance(past_key_values, Cache):
200
+ past_key_values = Cache.from_legacy_cache(past_key_values)
201
+
202
+ if self.gradient_checkpointing and self.training and use_cache:
203
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
204
+ use_cache = False
205
+
206
+ all_hidden_states = () if output_hidden_states else None
207
+ all_attns = () if output_attentions else None
208
+ for layer in self.layers:
209
+ if output_hidden_states:
210
+ all_hidden_states += (hidden_states,)
211
+
212
+ if self.gradient_checkpointing and self.training:
213
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
214
+ layer.__call__,
215
+ hidden_states,
216
+ attention_mask,
217
+ past_key_values,
218
+ use_cache,
219
+ output_attentions,
220
+ **kwargs
221
+ )
222
+ else:
223
+ hidden_states, attentions, past_key_values = layer(
224
+ hidden_states,
225
+ attention_mask=attention_mask,
226
+ past_key_values=past_key_values,
227
+ use_cache=use_cache,
228
+ output_attentions=output_attentions,
229
+ **kwargs
230
+ )
231
+
232
+ if output_attentions:
233
+ all_attns += (attentions,)
234
+
235
+ hidden_states = self.norm(hidden_states)
236
+
237
+ # add hidden states from the last decoder layer
238
+ if output_hidden_states:
239
+ all_hidden_states += (hidden_states,)
240
+
241
+ if not return_dict:
242
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
243
+ return BaseModelOutputWithPast(
244
+ last_hidden_state=hidden_states,
245
+ past_key_values=past_key_values,
246
+ hidden_states=all_hidden_states,
247
+ attentions=all_attns
248
+ )
249
+
250
+
251
+ class NSAForCausalLM(NSAPreTrainedModel, GenerationMixin):
252
+
253
+ _tied_weights_keys = ["lm_head.weight"]
254
+
255
+ def __init__(self, config):
256
+ super().__init__(config)
257
+ self.model = NSAModel(config)
258
+ self.vocab_size = config.vocab_size
259
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
260
+ self.criterion = None
261
+
262
+ # Initialize weights and apply final processing
263
+ self.post_init()
264
+
265
+ def get_input_embeddings(self):
266
+ return self.model.embeddings
267
+
268
+ def set_input_embeddings(self, value):
269
+ self.model.embeddings = value
270
+
271
+ def get_output_embeddings(self):
272
+ return self.lm_head
273
+
274
+ def set_output_embeddings(self, new_embeddings):
275
+ self.lm_head = new_embeddings
276
+
277
+ def set_decoder(self, decoder):
278
+ self.model = decoder
279
+
280
+ def get_decoder(self):
281
+ return self.model
282
+
283
+ def generate(self, *args, **kwargs):
284
+ try:
285
+ return super().generate(*args, **kwargs)
286
+ except AttributeError as exception:
287
+ if 'past_key_values' in str(exception):
288
+ raise AttributeError(
289
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
290
+ f"which is not supported for {self.__class__.__name__}. "
291
+ f"Try another generation strategy instead. "
292
+ f"For the available generation strategies, check this doc: "
293
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
294
+ )
295
+ else:
296
+ raise exception
297
+
298
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
299
+ def prepare_inputs_for_generation(
300
+ self,
301
+ input_ids: torch.LongTensor = None,
302
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
303
+ attention_mask: Optional[torch.Tensor] = None,
304
+ inputs_embeds: Optional[torch.Tensor] = None,
305
+ use_cache: bool = True,
306
+ logits_to_keep: Optional[int] = None,
307
+ **kwargs
308
+ ):
309
+ # only last token for `input_ids` if the `past_key_values` is not empty.
310
+ if past_key_values is not None and len(past_key_values) > 0:
311
+ input_ids = input_ids[:, -1:]
312
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
313
+ if inputs_embeds is not None and len(past_key_values) == 0:
314
+ model_inputs = {'inputs_embeds': inputs_embeds}
315
+ else:
316
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
317
+ # recompiles graphs as the stride of the inputs is a guard.
318
+ # Ref: https://github.com/huggingface/transformers/pull/29114
319
+ # TODO: use `next_tokens` directly instead.
320
+ model_inputs = {'input_ids': input_ids.contiguous()}
321
+
322
+ if logits_to_keep is not None:
323
+ model_inputs['logits_to_keep'] = logits_to_keep
324
+
325
+ model_inputs.update({
326
+ 'past_key_values': past_key_values,
327
+ 'use_cache': use_cache,
328
+ 'attention_mask': attention_mask,
329
+ })
330
+ return model_inputs
331
+
332
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
333
+ def forward(
334
+ self,
335
+ input_ids: torch.LongTensor = None,
336
+ attention_mask: Optional[torch.Tensor] = None,
337
+ inputs_embeds: Optional[torch.Tensor] = None,
338
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
339
+ labels: Optional[torch.LongTensor] = None,
340
+ use_cache: Optional[bool] = None,
341
+ output_attentions: Optional[bool] = None,
342
+ output_hidden_states: Optional[bool] = None,
343
+ return_dict: Optional[bool] = None,
344
+ logits_to_keep: Optional[int] = 0,
345
+ **kwargs: Unpack[Dict]
346
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
347
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
348
+ output_hidden_states = (
349
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
350
+ )
351
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
352
+
353
+ outputs = self.model(
354
+ input_ids=input_ids,
355
+ attention_mask=attention_mask,
356
+ inputs_embeds=inputs_embeds,
357
+ past_key_values=past_key_values,
358
+ use_cache=use_cache,
359
+ output_attentions=output_attentions,
360
+ output_hidden_states=output_hidden_states,
361
+ return_dict=return_dict,
362
+ **kwargs
363
+ )
364
+
365
+ hidden_states = outputs[0]
366
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
367
+
368
+ loss, logits = None, None
369
+ if not fuse_linear_and_cross_entropy or labels is None:
370
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
371
+ if labels is not None:
372
+ if getattr(self, 'criterion', None) is None:
373
+ if fuse_linear_and_cross_entropy:
374
+ criterion = FusedLinearCrossEntropyLoss()
375
+ elif self.config.fuse_cross_entropy:
376
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
377
+ else:
378
+ criterion = nn.CrossEntropyLoss()
379
+ else:
380
+ criterion = self.criterion
381
+ labels = labels.to(hidden_states.device)
382
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
383
+ if fuse_linear_and_cross_entropy:
384
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
385
+ else:
386
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
387
+
388
+ if not return_dict:
389
+ output = (logits,) + outputs[1:]
390
+ return (loss,) + output if loss is not None else output
391
+
392
+ return CausalLMOutputWithPast(
393
+ loss=loss,
394
+ logits=logits,
395
+ past_key_values=outputs.past_key_values,
396
+ hidden_states=outputs.hidden_states,
397
+ attentions=outputs.attentions,
398
+ )
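The `rescale` branch of `_init_weights` above follows the GPT-2 scheme quoted in its comments: output projections sitting on the residual path are re-initialized and then divided by sqrt(num_residuals_per_layer * num_hidden_layers). A standalone sketch of that scaling on a dummy projection (the sizes are illustrative, not taken from a real config):

```python
import math
import torch
import torch.nn as nn

num_residuals_per_layer = 2   # one attention + one MLP residual per block
num_hidden_layers = 24        # illustrative value

proj = nn.Linear(64, 64, bias=False)   # stands in for o_proj / down_proj
p = proj.weight
nn.init.kaiming_uniform_(p, a=math.sqrt(5))
with torch.no_grad():
    p /= math.sqrt(num_residuals_per_layer * num_hidden_layers)
```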
fla/models/retnet/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.retnet.configuration_retnet import RetNetConfig
6
+ from fla.models.retnet.modeling_retnet import RetNetForCausalLM, RetNetModel
7
+
8
+ AutoConfig.register(RetNetConfig.model_type, RetNetConfig)
9
+ AutoModel.register(RetNetConfig, RetNetModel)
10
+ AutoModelForCausalLM.register(RetNetConfig, RetNetForCausalLM)
11
+
12
+
13
+ __all__ = ['RetNetConfig', 'RetNetForCausalLM', 'RetNetModel']
fla/models/retnet/configuration_retnet.py ADDED
@@ -0,0 +1,92 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Dict, Optional
6
+
7
+ from transformers.configuration_utils import PretrainedConfig
8
+
9
+
10
+ class RetNetConfig(PretrainedConfig):
11
+
12
+ model_type = 'retnet'
13
+ keys_to_ignore_at_inference = ['past_key_values']
14
+
15
+ def __init__(
16
+ self,
17
+ attn_mode: str = "chunk",
18
+ hidden_size: int = 2048,
19
+ expand_k: int = 1,
20
+ expand_v: int = 2,
21
+ hidden_ratio: Optional[int] = 2,
22
+ intermediate_size: Optional[int] = None,
23
+ num_hidden_layers: int = 24,
24
+ num_heads: int = 8,
25
+ num_kv_heads: Optional[int] = None,
26
+ feature_map: Optional[str] = None,
27
+ hidden_act: str = "swish",
28
+ use_short_conv: bool = False,
29
+ conv_size: int = 4,
30
+ use_output_gate: bool = True,
31
+ max_position_embeddings: int = 2048,
32
+ elementwise_affine: Optional[bool] = True,
33
+ norm_eps: float = 1e-6,
34
+ attn: Optional[Dict] = None,
35
+ use_cache: bool = True,
36
+ pad_token_id: int = None,
37
+ bos_token_id: int = 1,
38
+ eos_token_id: int = 2,
39
+ tie_word_embeddings: bool = False,
40
+ initializer_range: float = 0.006,
41
+ fuse_norm: bool = True,
42
+ fuse_swiglu: bool = True,
43
+ fuse_cross_entropy: bool = True,
44
+ vocab_size: int = 32000,
45
+ **kwargs
46
+ ) -> RetNetConfig:
47
+ self.attn_mode = attn_mode
48
+ self.hidden_size = hidden_size
49
+ self.expand_k = expand_k
50
+ self.expand_v = expand_v
51
+ self.hidden_ratio = hidden_ratio
52
+ self.intermediate_size = intermediate_size
53
+ self.num_hidden_layers = num_hidden_layers
54
+ self.num_heads = num_heads
55
+ self.num_kv_heads = num_kv_heads
56
+ self.feature_map = feature_map
57
+ self.hidden_act = hidden_act
58
+ self.use_short_conv = use_short_conv
59
+ self.conv_size = conv_size
60
+ self.use_output_gate = use_output_gate
61
+ self.hidden_act = hidden_act
62
+ self.max_position_embeddings = max_position_embeddings
63
+ self.elementwise_affine = elementwise_affine
64
+ self.norm_eps = norm_eps
65
+ self.attn = attn
66
+ self.use_cache = use_cache
67
+ self.initializer_range = initializer_range
68
+
69
+ self.fuse_norm = fuse_norm
70
+ self.fuse_swiglu = fuse_swiglu
71
+ self.fuse_cross_entropy = fuse_cross_entropy
72
+ self.vocab_size = vocab_size
73
+
74
+ if attn is not None:
75
+ if not isinstance(attn, Dict):
76
+ raise ValueError("attn must be a dictionary")
77
+ if 'layers' not in attn:
78
+ raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
79
+ if 'num_heads' not in attn:
80
+ raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
81
+ attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
82
+ attn['qkv_bias'] = attn.get('qkv_bias', False)
83
+ attn['window_size'] = attn.get('window_size', None)
84
+ attn['rope_theta'] = attn.get('rope_theta', 10000.)
85
+
86
+ super().__init__(
87
+ pad_token_id=pad_token_id,
88
+ bos_token_id=bos_token_id,
89
+ eos_token_id=eos_token_id,
90
+ tie_word_embeddings=tie_word_embeddings,
91
+ **kwargs,
92
+ )
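Because `fla/models/retnet/__init__.py` above registers `RetNetConfig` with the `Auto*` classes, a config built from the defaults in this file can be turned into a model through the standard transformers factory. A minimal sketch, assuming the `fla` package is installed; the tiny sizes are only for illustration:

```python
from transformers import AutoModelForCausalLM

# importing the subpackage triggers the Auto* registration shown above
from fla.models.retnet import RetNetConfig

config = RetNetConfig(hidden_size=128, num_hidden_layers=2, num_heads=4, vocab_size=1000)
model = AutoModelForCausalLM.from_config(config)
print(sum(p.numel() for p in model.parameters()))
```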
fla/models/retnet/modeling_retnet.py ADDED
@@ -0,0 +1,425 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.multiscale_retention import MultiScaleRetention
20
+ from fla.models.retnet.configuration_retnet import RetNetConfig
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss
23
+ from fla.modules import GatedMLP as RetNetMLP
24
+ from fla.modules import RMSNorm
25
+
26
+ if TYPE_CHECKING:
27
+ from transformers.processing_utils import Unpack
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
+ class RetNetBlock(nn.Module):
33
+ def __init__(self, config: RetNetConfig, layer_idx: int):
34
+ super().__init__()
35
+
36
+ self.config = config
37
+ self.layer_idx = layer_idx
38
+
39
+ self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
40
+ if config.attn is not None and layer_idx in config.attn['layers']:
41
+ self.attn = Attention(
42
+ hidden_size=config.hidden_size,
43
+ num_heads=config.attn['num_heads'],
44
+ num_kv_heads=config.attn['num_kv_heads'],
45
+ qkv_bias=config.attn['qkv_bias'],
46
+ window_size=config.attn['window_size'],
47
+ rope_theta=config.attn['rope_theta'],
48
+ max_position_embeddings=config.max_position_embeddings,
49
+ layer_idx=layer_idx
50
+ )
51
+ else:
52
+ self.attn = MultiScaleRetention(
53
+ mode=config.attn_mode,
54
+ hidden_size=config.hidden_size,
55
+ expand_k=config.expand_k,
56
+ expand_v=config.expand_v,
57
+ num_heads=config.num_heads,
58
+ num_kv_heads=config.num_kv_heads,
59
+ feature_map=config.feature_map,
60
+ use_output_gate=config.use_output_gate,
61
+ gate_fn=config.hidden_act,
62
+ elementwise_affine=config.elementwise_affine,
63
+ norm_eps=config.norm_eps,
64
+ fuse_norm=config.fuse_norm,
65
+ layer_idx=layer_idx
66
+ )
67
+ self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
68
+ self.mlp = RetNetMLP(
69
+ hidden_size=config.hidden_size,
70
+ hidden_ratio=config.hidden_ratio,
71
+ intermediate_size=config.intermediate_size,
72
+ hidden_act=config.hidden_act,
73
+ fuse_swiglu=config.fuse_swiglu
74
+ )
75
+
76
+ def forward(
77
+ self,
78
+ hidden_states: torch.Tensor,
79
+ attention_mask: Optional[torch.Tensor] = None,
80
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
81
+ use_cache: Optional[bool] = False,
82
+ output_attentions: Optional[bool] = False,
83
+ **kwargs: Unpack[Dict]
84
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
85
+
86
+ residual = hidden_states
87
+
88
+ hidden_states = self.attn_norm(hidden_states)
89
+ hidden_states, attentions, past_key_values = self.attn(
90
+ hidden_states=hidden_states,
91
+ attention_mask=attention_mask,
92
+ past_key_values=past_key_values,
93
+ use_cache=use_cache,
94
+ output_attentions=output_attentions,
95
+ **kwargs
96
+ )
97
+ if self.config.fuse_norm:
98
+ hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
99
+ else:
100
+ hidden_states = residual + hidden_states
101
+ residual = hidden_states
102
+ hidden_states = self.mlp_norm(hidden_states)
103
+ hidden_states = self.mlp(hidden_states, **kwargs)
104
+ hidden_states = residual + hidden_states
105
+
106
+ outputs = (hidden_states, attentions, past_key_values)
107
+
108
+ return outputs
109
+
110
+
111
+ class RetNetPreTrainedModel(PreTrainedModel):
112
+
113
+ config_class = RetNetConfig
114
+ base_model_prefix = 'model'
115
+ supports_gradient_checkpointing = True
116
+ _no_split_modules = ['RetNetBlock']
117
+ _supports_cache_class = True
118
+
119
+ def __init__(self, *inputs, **kwargs):
120
+ super().__init__(*inputs, **kwargs)
121
+
122
+ def _init_weights(
123
+ self,
124
+ module: nn.Module,
125
+ prenorm_residual_strategy: Optional[str] = 'rescale',
126
+ num_residuals_per_layer: int = 2,
127
+ ):
128
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
129
+ # Slightly different from the TF version which uses truncated_normal for initialization
130
+ # cf https://github.com/pytorch/pytorch/pull/5617
131
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
132
+ if module.bias is not None:
133
+ nn.init.zeros_(module.bias)
134
+ elif isinstance(module, nn.Embedding):
135
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
136
+ elif hasattr(module, 'reset_parameters'):
137
+ module.reset_parameters()
138
+
139
+ if prenorm_residual_strategy is not None:
140
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
141
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
142
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
143
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
144
+ #
145
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
146
+ p = None
147
+ if hasattr(module, 'o_proj'):
148
+ p = module.o_proj.weight
149
+ elif hasattr(module, 'down_proj'):
150
+ p = module.down_proj.weight
151
+ if p is not None:
152
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
153
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
154
+ # We need to reinit p since this code could be called multiple times
155
+ # Having just p *= scale would repeatedly scale it down
156
+ if prenorm_residual_strategy == 'rescale':
157
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
158
+ with torch.no_grad():
159
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
160
+ elif prenorm_residual_strategy == 'zero':
161
+ nn.init.zeros_(p)
162
+ else:
163
+ raise ValueError(f"Invalid prenorm_residual_strategy: {prenorm_residual_strategy}")
164
+
165
+
166
+ class RetNetModel(RetNetPreTrainedModel):
167
+
168
+ def __init__(self, config: RetNetConfig):
169
+ super().__init__(config)
170
+ self.padding_idx = config.pad_token_id
171
+ self.vocab_size = config.vocab_size
172
+
173
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
174
+ self.layers = nn.ModuleList(
175
+ [RetNetBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
176
+ )
177
+ self.norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(config.hidden_size, eps=config.norm_eps)
178
+
179
+ self.gradient_checkpointing = False
180
+
181
+ self.post_init()
182
+
183
+ def get_input_embeddings(self):
184
+ return self.embeddings
185
+
186
+ def set_input_embeddings(self, value):
187
+ self.embeddings = value
188
+
189
+ def forward(
190
+ self,
191
+ input_ids: Optional[torch.LongTensor] = None,
192
+ attention_mask: Optional[torch.Tensor] = None, # noqa
193
+ inputs_embeds: Optional[torch.FloatTensor] = None,
194
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
195
+ use_cache: Optional[bool] = None,
196
+ output_attentions: Optional[bool] = None,
197
+ output_hidden_states: Optional[bool] = None,
198
+ return_dict: Optional[bool] = None,
199
+ **kwargs: Unpack[Dict]
200
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
201
+ if output_attentions:
202
+ warnings.warn(
203
+ "`RetNetModel` does not support output attention weights now, so `output_attentions` is set to `False`."
204
+ )
205
+ output_attentions = False
206
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
207
+ output_hidden_states = (
208
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
209
+ )
210
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
211
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
212
+
213
+ # retrieve input_ids and inputs_embeds
214
+ if input_ids is not None and inputs_embeds is not None:
215
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
216
+ if input_ids is None and inputs_embeds is None:
217
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
218
+
219
+ if inputs_embeds is None:
220
+ inputs_embeds = self.embeddings(input_ids)
221
+ hidden_states = inputs_embeds
222
+
223
+ if use_cache and not isinstance(past_key_values, Cache):
224
+ past_key_values = Cache.from_legacy_cache(past_key_values)
225
+
226
+ if self.gradient_checkpointing and self.training and use_cache:
227
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
228
+ use_cache = False
229
+
230
+ all_hidden_states = () if output_hidden_states else None
231
+ all_attns = () if output_attentions else None
232
+ for layer in self.layers:
233
+ if output_hidden_states:
234
+ all_hidden_states += (hidden_states,)
235
+
236
+ if self.gradient_checkpointing and self.training:
237
+ hidden_states, attentions, past_key_values = self._gradient_checkpointing_func(
238
+ layer.__call__,
239
+ hidden_states,
240
+ attention_mask,
241
+ past_key_values,
242
+ use_cache,
243
+ output_attentions,
244
+ **kwargs
245
+ )
246
+ else:
247
+ hidden_states, attentions, past_key_values = layer(
248
+ hidden_states,
249
+ attention_mask=attention_mask,
250
+ past_key_values=past_key_values,
251
+ use_cache=use_cache,
252
+ output_attentions=output_attentions,
253
+ **kwargs
254
+ )
255
+
256
+ if output_attentions:
257
+ all_attns += (attentions,)
258
+
259
+ hidden_states = self.norm(hidden_states)
260
+
261
+ # add hidden states from the last decoder layer
262
+ if output_hidden_states:
263
+ all_hidden_states += (hidden_states,)
264
+
265
+ if not return_dict:
266
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
267
+ return BaseModelOutputWithPast(
268
+ last_hidden_state=hidden_states,
269
+ past_key_values=past_key_values,
270
+ hidden_states=all_hidden_states,
271
+ attentions=all_attns
272
+ )
273
+
274
+
275
+ class RetNetForCausalLM(RetNetPreTrainedModel, GenerationMixin):
276
+
277
+ _tied_weights_keys = ["lm_head.weight"]
278
+
279
+ def __init__(self, config):
280
+ super().__init__(config)
281
+ self.model = RetNetModel(config)
282
+ self.vocab_size = config.vocab_size
283
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
284
+ self.criterion = None
285
+
286
+ # Initialize weights and apply final processing
287
+ self.post_init()
288
+
289
+ def get_input_embeddings(self):
290
+ return self.model.embeddings
291
+
292
+ def set_input_embeddings(self, value):
293
+ self.model.embeddings = value
294
+
295
+ def get_output_embeddings(self):
296
+ return self.lm_head
297
+
298
+ def set_output_embeddings(self, new_embeddings):
299
+ self.lm_head = new_embeddings
300
+
301
+ def set_decoder(self, decoder):
302
+ self.model = decoder
303
+
304
+ def get_decoder(self):
305
+ return self.model
306
+
307
+ def generate(self, *args, **kwargs):
308
+ try:
309
+ return super().generate(*args, **kwargs)
310
+ except AttributeError as exception:
311
+ # Expected exception: "AttributeError: '(object name)' object has no attribute 'past_key_values'"
312
+ if 'past_key_values' in str(exception):
313
+ raise AttributeError(
314
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
315
+ f"which is not supported for {self.__class__.__name__}. "
316
+ f"Try another generation strategy instead. "
317
+ f"For the available generation strategies, check this doc: "
318
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
319
+ )
320
+ else:
321
+ raise exception
322
+
323
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
324
+ def prepare_inputs_for_generation(
325
+ self,
326
+ input_ids: torch.LongTensor = None,
327
+ past_key_values: Optional[torch.Tensor] = None,
328
+ attention_mask: Optional[torch.Tensor] = None,
329
+ inputs_embeds: Optional[torch.FloatTensor] = None,
330
+ use_cache: Optional[bool] = True,
331
+ logits_to_keep: Optional[int] = None,
332
+ **kwargs: Unpack[Dict]
333
+ ):
334
+ # only last token for `input_ids` if the `past_key_values` is passed along.
335
+ if past_key_values is not None:
336
+ input_ids = input_ids[:, -1:]
337
+
338
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
339
+ if inputs_embeds is not None and len(past_key_values) == 0:
340
+ model_inputs = {'inputs_embeds': inputs_embeds}
341
+ else:
342
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
343
+ # recompiles graphs as the stride of the inputs is a guard.
344
+ # Ref: https://github.com/huggingface/transformers/pull/29114
345
+ # TODO: use `next_tokens` directly instead.
346
+ model_inputs = {'input_ids': input_ids.contiguous()}
347
+
348
+ if logits_to_keep is not None:
349
+ model_inputs['logits_to_keep'] = logits_to_keep
350
+
351
+ model_inputs.update({
352
+ 'past_key_values': past_key_values,
353
+ 'use_cache': use_cache,
354
+ 'attention_mask': attention_mask,
355
+ })
356
+ return model_inputs
357
+
358
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
359
+ def forward(
360
+ self,
361
+ input_ids: torch.LongTensor = None,
362
+ attention_mask: Optional[torch.Tensor] = None,
363
+ inputs_embeds: Optional[torch.FloatTensor] = None,
364
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
365
+ labels: Optional[torch.LongTensor] = None,
366
+ use_cache: Optional[bool] = None,
367
+ output_attentions: Optional[bool] = None,
368
+ output_hidden_states: Optional[bool] = None,
369
+ return_dict: Optional[bool] = None,
370
+ logits_to_keep: Optional[int] = 0,
371
+ **kwargs: Unpack[Dict]
372
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
373
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
374
+ output_hidden_states = (
375
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
376
+ )
377
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
378
+
379
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
380
+ outputs = self.model(
381
+ input_ids=input_ids,
382
+ attention_mask=attention_mask,
383
+ inputs_embeds=inputs_embeds,
384
+ past_key_values=past_key_values,
385
+ use_cache=use_cache,
386
+ output_attentions=output_attentions,
387
+ output_hidden_states=output_hidden_states,
388
+ return_dict=return_dict,
389
+ **kwargs
390
+ )
391
+
392
+ hidden_states = outputs[0]
393
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
394
+
395
+ loss, logits = None, None
396
+ if not fuse_linear_and_cross_entropy or labels is None:
397
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
398
+ if labels is not None:
399
+ if getattr(self, 'criterion', None) is None:
400
+ if fuse_linear_and_cross_entropy:
401
+ criterion = FusedLinearCrossEntropyLoss()
402
+ elif self.config.fuse_cross_entropy:
403
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
404
+ else:
405
+ criterion = nn.CrossEntropyLoss()
406
+ else:
407
+ criterion = self.criterion
408
+ labels = labels.to(hidden_states.device)
409
+ labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
410
+ if fuse_linear_and_cross_entropy:
411
+ loss = criterion(hidden_states, labels, self.lm_head.weight, self.lm_head.bias)
412
+ else:
413
+ loss = criterion(logits.view(labels.numel(), -1), labels.view(-1))
414
+
415
+ if not return_dict:
416
+ output = (logits,) + outputs[1:]
417
+ return (loss,) + output if loss is not None else output
418
+
419
+ return CausalLMOutputWithPast(
420
+ loss=loss,
421
+ logits=logits,
422
+ past_key_values=outputs.past_key_values,
423
+ hidden_states=outputs.hidden_states,
424
+ attentions=outputs.attentions,
425
+ )
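One subtlety in the `forward` above (shared by the other causal LM heads in this commit) is the `hidden_states[:, -logits_to_keep:]` slice: with the default `logits_to_keep=0`, `-0 == 0`, so the slice keeps every position, while `logits_to_keep=1` keeps only the last one. A quick sketch:

```python
import torch

hidden_states = torch.randn(1, 5, 8)  # (batch, seq_len, hidden)

logits_to_keep = 0
print(hidden_states[:, -logits_to_keep:].shape)  # torch.Size([1, 5, 8]): -0 == 0 keeps the full sequence

logits_to_keep = 1
print(hidden_states[:, -logits_to_keep:].shape)  # torch.Size([1, 1, 8]): only the last position
```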
fla/models/rwkv6/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (687 Bytes). View file
 
fla/models/rwkv6/__pycache__/modeling_rwkv6.cpython-312.pyc ADDED
Binary file (21.2 kB). View file
 
fla/models/rwkv7/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (687 Bytes). View file
 
fla/models/rwkv7/__pycache__/modeling_rwkv7.cpython-312.pyc ADDED
Binary file (22.3 kB). View file
 
fla/models/rwkv7/modeling_rwkv7.py ADDED
@@ -0,0 +1,505 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.generation import GenerationMixin
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import logging
16
+ from transformers.utils.deprecation import deprecate_kwarg
17
+
18
+ from fla.layers.attn import Attention
19
+ from fla.layers.rwkv7 import RWKV7Attention
20
+ from fla.models.rwkv7.configuration_rwkv7 import RWKV7Config
21
+ from fla.models.utils import Cache
22
+ from fla.modules import FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss, LayerNorm
23
+ from fla.modules.activations import ACT2FN
24
+
25
+ if TYPE_CHECKING:
26
+ from transformers.processing_utils import Unpack
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ class RWKV7FeedForward(nn.Module):
32
+
33
+ def __init__(
34
+ self,
35
+ hidden_size: int,
36
+ hidden_ratio: Optional[int] = None,
37
+ intermediate_size: Optional[int] = None,
38
+ hidden_act: str = 'sqrelu',
39
+ layer_idx: int = None
40
+ ) -> RWKV7FeedForward:
41
+ super().__init__()
42
+
43
+ self.hidden_size = hidden_size
44
+ if hidden_ratio is None:
45
+ hidden_ratio = 4
46
+ if intermediate_size is None:
47
+ intermediate_size = int(hidden_size * hidden_ratio)
48
+ intermediate_size = 32 * ((intermediate_size + 32 - 1) // 32)
49
+ self.hidden_ratio = hidden_ratio
50
+ self.intermediate_size = intermediate_size
51
+
52
+ self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
53
+
54
+ self.x_k = nn.Parameter(torch.zeros(hidden_size))
55
+
56
+ self.key = nn.Linear(hidden_size, intermediate_size, bias=False)
57
+ self.value = nn.Linear(intermediate_size, hidden_size, bias=False)
58
+ self.act_fn = ACT2FN[hidden_act]
59
+
60
+ self.layer_idx = layer_idx
61
+
62
+ def forward(
63
+ self,
64
+ x: torch.Tensor,
65
+ attention_mask: Optional[torch.Tensor] = None,
66
+ state: Optional[Cache] = None
67
+ ) -> torch.Tensor:
68
+ if attention_mask is not None:
69
+ x = x.mul(attention_mask[:, -x.shape[-2]:, None])
70
+ if x.shape[1] == 1 and state is not None and state[self.layer_idx]['ffn_state'] is not None:
71
+ shifted = state[self.layer_idx]['ffn_state'].unsqueeze(1)
72
+ else:
73
+ shifted = self.time_shift(x)
74
+ if state is not None and state[self.layer_idx]['ffn_state'] is not None:
75
+ shifted[:, 0] = state[self.layer_idx]['ffn_state'][-1]
76
+ if state is not None:
77
+ # no need to update the offset twice
78
+ state.update(ffn_state=x[:, -1], layer_idx=self.layer_idx, offset=0)
79
+ return self.value(self.act_fn(self.key(x.addcmul(shifted - x, self.x_k)))), state
80
+
81
+
82
+ class RWKV7Block(nn.Module):
83
+
84
+ def __init__(
85
+ self,
86
+ config: RWKV7Config,
87
+ layer_idx: int
88
+ ) -> RWKV7Block:
89
+ super().__init__()
90
+
91
+ self.config = config
92
+ self.layer_idx = layer_idx
93
+
94
+ if config.norm_first and layer_idx == 0:
95
+ self.pre_norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)(
96
+ config.hidden_size,
97
+ bias=config.norm_bias,
98
+ eps=config.norm_eps
99
+ )
100
+ self.attn_norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)(
101
+ config.hidden_size,
102
+ bias=config.norm_bias,
103
+ eps=config.norm_eps
104
+ )
105
+ if config.attn is not None and layer_idx in config.attn['layers']:
106
+ self.attn = Attention(
107
+ hidden_size=config.hidden_size,
108
+ num_heads=config.attn['num_heads'],
109
+ num_kv_heads=config.attn['num_kv_heads'],
110
+ qkv_bias=config.attn['qkv_bias'],
111
+ window_size=config.attn['window_size'],
112
+ rope_theta=config.attn['rope_theta'],
113
+ max_position_embeddings=config.max_position_embeddings,
114
+ layer_idx=layer_idx
115
+ )
116
+ else:
117
+ self.attn = RWKV7Attention(
118
+ mode=config.attn_mode,
119
+ hidden_size=config.hidden_size,
120
+ head_dim=config.head_dim,
121
+ num_heads=config.num_heads,
122
+ decay_low_rank_dim=config.decay_low_rank_dim,
123
+ gate_low_rank_dim=config.gate_low_rank_dim,
124
+ a_low_rank_dim=config.a_low_rank_dim,
125
+ v_low_rank_dim=config.v_low_rank_dim,
126
+ norm_eps=config.norm_eps,
127
+ fuse_norm=config.fuse_norm,
128
+ layer_idx=layer_idx,
129
+ value_dim=config.value_dim[layer_idx]
130
+ )
131
+ self.ffn_norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)(
132
+ config.hidden_size,
133
+ bias=config.norm_bias,
134
+ eps=config.norm_eps
135
+ )
136
+ self.ffn = RWKV7FeedForward(
137
+ hidden_size=config.hidden_size,
138
+ hidden_ratio=config.hidden_ratio,
139
+ intermediate_size=config.intermediate_size,
140
+ hidden_act=config.hidden_act,
141
+ layer_idx=layer_idx
142
+ )
143
+
144
+ def forward(
145
+ self,
146
+ hidden_states: torch.Tensor,
147
+ attention_mask: Optional[torch.Tensor] = None,
148
+ past_key_values: Optional[Cache] = None,
149
+ use_cache: Optional[bool] = False,
150
+ output_attentions: Optional[bool] = False,
151
+ v_first: Optional[torch.Tensor] = None,
152
+ **kwargs,
153
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
154
+ residual = self.pre_norm(hidden_states) if hasattr(self, 'pre_norm') else hidden_states
155
+ hidden_states = self.attn_norm(residual)
156
+ hidden_states, attentions, past_key_values, v_first = self.attn(
157
+ hidden_states=hidden_states,
158
+ attention_mask=attention_mask,
159
+ past_key_values=past_key_values,
160
+ use_cache=use_cache,
161
+ output_attentions=output_attentions,
162
+ v_first=v_first,
163
+ **kwargs
164
+ )
165
+ if self.config.fuse_norm:
166
+ hidden_states, residual = self.ffn_norm(hidden_states, residual, True)
167
+ else:
168
+ hidden_states = residual + hidden_states
169
+ residual = hidden_states
170
+ hidden_states = self.ffn_norm(hidden_states)
171
+ hidden_states, past_key_values = self.ffn(hidden_states, attention_mask, past_key_values)
172
+ hidden_states = residual + hidden_states
173
+
174
+ outputs = (hidden_states, attentions, past_key_values, v_first)
175
+
176
+ return outputs
177
+
178
+
179
+ class RWKV7PreTrainedModel(PreTrainedModel):
180
+
181
+ config_class = RWKV7Config
182
+ base_model_prefix = 'model'
183
+ supports_gradient_checkpointing = True
184
+ _no_split_modules = ['RWKV7Block']
185
+ _supports_cache_class = True
186
+ _skip_keys_device_placement = ["past_key_values"]
187
+
188
+ def __init__(self, *inputs, **kwargs):
189
+ super().__init__(*inputs, **kwargs)
190
+
191
+ def _init_weights(
192
+ self,
193
+ module: nn.Module,
194
+ rescale_prenorm_residual: bool = True,
195
+ num_residuals_per_layer: int = 2,
196
+ ):
197
+ warnings.warn(
198
+ "RWKV-7 employs a carefully designed initialization strategy tailored to its architecture. "
199
+ "The detailed initialization scheme is currently not implemented here but can be found in the "
200
+ "official code repository. We emphasize that using the recommended initialization is essential "
201
+ "for replicating the results reported in the RWKV-7 paper. Deviations from the prescribed initialization "
202
+ "may lead to performance degradation.\n"
203
+ "Alternatively, please generate initial weights from the official RWKV code repository, and "
204
+ "convert the PyTorch checkpoint into the FLA-supported format."
205
+ )
206
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
207
+ # Slightly different from the TF version which uses truncated_normal for initialization
208
+ # cf https://github.com/pytorch/pytorch/pull/5617
209
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
210
+ if module.bias is not None:
211
+ nn.init.zeros_(module.bias)
212
+ elif isinstance(module, nn.Parameter):
213
+ nn.init.normal_(module, mean=0.0, std=self.config.initializer_range)
214
+ elif isinstance(module, nn.Embedding):
215
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
216
+ elif hasattr(module, 'reset_parameters'):
217
+ module.reset_parameters()
218
+
219
+ if rescale_prenorm_residual:
220
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
221
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
222
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
223
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
224
+ #
225
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
226
+ p = None
227
+ if hasattr(module, 'o_proj'):
228
+ p = module.o_proj.weight
229
+ elif hasattr(module, 'down_proj'):
230
+ p = module.down_proj.weight
231
+ if p is not None:
232
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
233
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
234
+ # We need to reinit p since this code could be called multiple times
235
+ # Having just p *= scale would repeatedly scale it down
236
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
237
+ with torch.no_grad():
238
+ p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
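+ # e.g. 24 hidden layers with 2 residual additions per block (attn + ffn) gives a scale of 1/sqrt(48) ≈ 0.14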
239
+
240
+
241
+ class RWKV7Model(RWKV7PreTrainedModel):
242
+
243
+ def __init__(self, config: RWKV7Config):
244
+ super().__init__(config)
245
+ self.padding_idx = config.pad_token_id
246
+ self.vocab_size = config.vocab_size
247
+
248
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
249
+ self.layers = nn.ModuleList([RWKV7Block(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
250
+ self.norm = (LayerNorm if config.fuse_norm else nn.LayerNorm)(
251
+ config.hidden_size,
252
+ bias=config.norm_bias,
253
+ eps=config.norm_eps
254
+ )
255
+
256
+ self.gradient_checkpointing = False
257
+
258
+ self.post_init()
259
+
260
+ def get_input_embeddings(self):
261
+ return self.embeddings
262
+
263
+ def set_input_embeddings(self, value):
264
+ self.embeddings = value
265
+
266
+ def forward(
267
+ self,
268
+ input_ids: Optional[torch.LongTensor] = None,
269
+ attention_mask: Optional[torch.Tensor] = None, # noqa
270
+ inputs_embeds: Optional[torch.FloatTensor] = None,
271
+ past_key_values: Optional[Cache] = None,
272
+ use_cache: Optional[bool] = None,
273
+ output_attentions: Optional[bool] = None,
274
+ output_hidden_states: Optional[bool] = None,
275
+ return_dict: Optional[bool] = None,
276
+ **kwargs: Unpack[Dict]
277
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
278
+ if output_attentions:
279
+ warnings.warn("`RWKV7Model` does not support `output_attentions` for now; setting it to `False`.")
280
+ output_attentions = False
281
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
282
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
283
+ use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
284
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
285
+
286
+ # retrieve input_ids and inputs_embeds
287
+ if input_ids is not None and inputs_embeds is not None:
288
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
289
+ if input_ids is None and inputs_embeds is None:
290
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
291
+
292
+ if inputs_embeds is None:
293
+ inputs_embeds = self.embeddings(input_ids)
294
+ hidden_states = inputs_embeds
295
+
296
+ if use_cache and not isinstance(past_key_values, Cache):
297
+ past_key_values = Cache.from_legacy_cache(past_key_values)
298
+
299
+ if self.gradient_checkpointing and self.training and use_cache:
300
+ logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
301
+ use_cache = False
302
+
303
+ all_hidden_states = () if output_hidden_states else None
304
+ all_attns = () if output_attentions else None
305
+
306
+ v_first = torch.zeros_like(hidden_states)
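+ # RWKV-7 value residual: the first block fills `v_first` with its value activations and
+ # later blocks mix it into their own values, so the buffer is threaded through the layer loop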
307
+ for layer in self.layers:
308
+ if output_hidden_states:
309
+ all_hidden_states += (hidden_states,)
310
+
311
+ if self.gradient_checkpointing and self.training:
312
+ hidden_states, attentions, past_key_values, v_first = self._gradient_checkpointing_func(
313
+ layer.__call__,
314
+ hidden_states,
315
+ attention_mask,
316
+ past_key_values,
317
+ use_cache,
318
+ output_attentions,
319
+ v_first,
320
+ **kwargs
321
+ )
322
+ else:
323
+ hidden_states, attentions, past_key_values, v_first = layer(
324
+ hidden_states,
325
+ attention_mask=attention_mask,
326
+ past_key_values=past_key_values,
327
+ use_cache=use_cache,
328
+ output_attentions=output_attentions,
329
+ v_first=v_first,
330
+ **kwargs
331
+ )
332
+
333
+ if output_attentions:
334
+ all_attns += (attentions,)
335
+
336
+ hidden_states = self.norm(hidden_states)
337
+
338
+ # add hidden states from the last decoder layer
339
+ if output_hidden_states:
340
+ all_hidden_states += (hidden_states,)
341
+
342
+ if not return_dict:
343
+ return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
344
+ return BaseModelOutputWithPast(
345
+ last_hidden_state=hidden_states,
346
+ past_key_values=past_key_values,
347
+ hidden_states=all_hidden_states,
348
+ attentions=all_attns
349
+ )
350
+
351
+
352
+ class RWKV7ForCausalLM(RWKV7PreTrainedModel, GenerationMixin):
353
+
354
+ _tied_weights_keys = ["lm_head.weight"]
355
+
356
+ def __init__(self, config):
357
+ super().__init__(config)
358
+ self.model = RWKV7Model(config)
359
+ self.vocab_size = config.vocab_size
360
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
361
+ self.criterion = None
362
+
363
+ # Initialize weights and apply final processing
364
+ self.post_init()
365
+
366
+ def get_input_embeddings(self):
367
+ return self.model.embeddings
368
+
369
+ def set_input_embeddings(self, value):
370
+ self.model.embeddings = value
371
+
372
+ def get_output_embeddings(self):
373
+ return self.lm_head
374
+
375
+ def set_output_embeddings(self, new_embeddings):
376
+ self.lm_head = new_embeddings
377
+
378
+ def set_decoder(self, decoder):
379
+ self.model = decoder
380
+
381
+ def get_decoder(self):
382
+ return self.model
383
+
384
+ def generate(self, *args, **kwargs):
385
+ try:
386
+ return super().generate(*args, **kwargs)
387
+ except AttributeError as exception:
388
+ if 'past_key_values' in str(exception):
389
+ raise AttributeError(
390
+ f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
391
+ f"which is not supported for {self.__class__.__name__}. "
392
+ f"Try another generation strategy instead. "
393
+ f"For the available generation strategies, check this doc: "
394
+ f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
395
+ )
396
+ else:
397
+ raise exception
398
+
399
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
400
+ def prepare_inputs_for_generation(
401
+ self,
402
+ input_ids: torch.LongTensor = None,
403
+ past_key_values: Optional[Cache] = None,
404
+ attention_mask: Optional[torch.Tensor] = None,
405
+ inputs_embeds: Optional[torch.Tensor] = None,
406
+ use_cache: bool = True,
407
+ logits_to_keep: Optional[int] = None,
408
+ **kwargs
409
+ ):
410
+ # only keep the last token of `input_ids` if `past_key_values` is not empty.
411
+ if past_key_values is not None and len(past_key_values) > 0:
412
+ input_ids = input_ids[:, -1:]
413
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
414
+ if inputs_embeds is not None and (past_key_values is None or len(past_key_values) == 0):
415
+ model_inputs = {'inputs_embeds': inputs_embeds}
416
+ else:
417
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
418
+ # recompiles graphs as the stride of the inputs is a guard.
419
+ # Ref: https://github.com/huggingface/transformers/pull/29114
420
+ # TODO: use `next_tokens` directly instead.
421
+ model_inputs = {'input_ids': input_ids.contiguous()}
422
+
423
+ if logits_to_keep is not None:
424
+ model_inputs['logits_to_keep'] = logits_to_keep
425
+
426
+ model_inputs.update({
427
+ 'past_key_values': past_key_values,
428
+ 'use_cache': use_cache,
429
+ 'attention_mask': attention_mask,
430
+ })
431
+ return model_inputs
432
+
433
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
434
+ def forward(
435
+ self,
436
+ input_ids: torch.LongTensor = None,
437
+ attention_mask: Optional[torch.Tensor] = None,
438
+ inputs_embeds: Optional[torch.Tensor] = None,
439
+ past_key_values: Optional[Cache] = None,
440
+ labels: Optional[torch.LongTensor] = None,
441
+ shift_labels: Optional[torch.LongTensor] = None,
442
+ use_cache: Optional[bool] = None,
443
+ output_attentions: Optional[bool] = None,
444
+ output_hidden_states: Optional[bool] = None,
445
+ return_dict: Optional[bool] = None,
446
+ logits_to_keep: Optional[int] = 0,
447
+ **kwargs: Unpack[Dict]
448
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
449
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
450
+ output_hidden_states = (
451
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
452
+ )
453
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
454
+
455
+ outputs = self.model(
456
+ input_ids=input_ids,
457
+ attention_mask=attention_mask,
458
+ inputs_embeds=inputs_embeds,
459
+ past_key_values=past_key_values,
460
+ use_cache=use_cache,
461
+ output_attentions=output_attentions,
462
+ output_hidden_states=output_hidden_states,
463
+ return_dict=return_dict,
464
+ **kwargs
465
+ )
466
+
467
+ hidden_states = outputs[0]
468
+ fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
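+ # on the fused path, the lm_head projection and cross-entropy are computed jointly
+ # (FusedLinearCrossEntropyLoss) without materializing full logits, so `logits` stays None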
469
+
470
+ loss, logits = None, None
471
+ has_labels = (labels is not None) or (shift_labels is not None)
472
+ if not (fuse_linear_and_cross_entropy and has_labels):
473
+ logits = self.lm_head(hidden_states if logits_to_keep is None else hidden_states[:, -logits_to_keep:])
474
+ if has_labels:
475
+ if getattr(self, 'criterion', None) is None:
476
+ if fuse_linear_and_cross_entropy:
477
+ criterion = FusedLinearCrossEntropyLoss()
478
+ elif self.config.fuse_cross_entropy:
479
+ criterion = FusedCrossEntropyLoss(inplace_backward=True)
480
+ else:
481
+ criterion = nn.CrossEntropyLoss()
482
+ else:
483
+ criterion = self.criterion
484
+
485
+ # shift_labels: See https://github.com/huggingface/transformers/pull/36607/files.
486
+ if shift_labels is None:
487
+ shift_labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], criterion.ignore_index)), 1)
488
+ shift_labels = shift_labels.to(hidden_states.device)
489
+
490
+ if fuse_linear_and_cross_entropy:
491
+ loss = criterion(hidden_states, shift_labels, self.lm_head.weight, self.lm_head.bias)
492
+ else:
493
+ loss = criterion(logits.view(shift_labels.numel(), -1), shift_labels.view(-1))
494
+
495
+ if not return_dict:
496
+ output = (logits,) + outputs[1:]
497
+ return (loss,) + output if loss is not None else output
498
+
499
+ return CausalLMOutputWithPast(
500
+ loss=loss,
501
+ logits=logits,
502
+ past_key_values=outputs.past_key_values,
503
+ hidden_states=outputs.hidden_states,
504
+ attentions=outputs.attentions,
505
+ )
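
A minimal usage sketch of the classes above (not part of the commit): it assumes that `fla.models.rwkv7` re-exports `RWKV7Config` and `RWKV7ForCausalLM`, that the toy config values below are valid (they are assumptions, not defaults from this diff), and that a GPU with Triton is available for the RWKV-7 kernels.

import torch
from fla.models.rwkv7 import RWKV7Config, RWKV7ForCausalLM

# tiny, assumption-based config; the field names follow the attributes read in modeling_rwkv7.py
config = RWKV7Config(hidden_size=256, num_hidden_layers=2, vocab_size=1000)
model = RWKV7ForCausalLM(config).cuda().eval()

input_ids = torch.randint(0, config.vocab_size, (1, 16), device='cuda')
with torch.no_grad():
    outputs = model(input_ids=input_ids, use_cache=True)
print(outputs.logits.shape)      # torch.Size([1, 16, 1000])
print(outputs.past_key_values)   # fla Cache holding per-layer recurrent and ffn states

Note that `prepare_inputs_for_generation` above feeds only the last token once the cache is non-empty, so per-step decoding cost is independent of context length.
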
fla/models/samba/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.samba.configuration_samba import SambaConfig
6
+ from fla.models.samba.modeling_samba import SambaBlock, SambaForCausalLM, SambaModel
7
+
8
+ AutoConfig.register(SambaConfig.model_type, SambaConfig, True)
9
+ AutoModel.register(SambaConfig, SambaModel, True)
10
+ AutoModelForCausalLM.register(SambaConfig, SambaForCausalLM, True)
11
+
12
+
13
+ __all__ = ['SambaConfig', 'SambaForCausalLM', 'SambaModel', 'SambaBlock']
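
The three `register` calls above make the model reachable through the `transformers` Auto classes. A hedged sketch (randomly initialized weights; the default config values are assumptions):

from transformers import AutoModelForCausalLM
from fla.models.samba import SambaConfig  # importing the package runs the register() calls above

model = AutoModelForCausalLM.from_config(SambaConfig())
print(type(model).__name__)  # -> 'SambaForCausalLM'
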
fla/models/samba/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (717 Bytes). View file