zaydzuhri commited on
Commit
0b38110
·
verified ·
1 Parent(s): 722383d

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. fla/ops/abc/__pycache__/__init__.cpython-312.pyc +0 -0
  2. fla/ops/attn/__pycache__/__init__.cpython-312.pyc +0 -0
  3. fla/ops/attn/parallel.py +629 -0
  4. fla/ops/based/__pycache__/__init__.cpython-312.pyc +0 -0
  5. fla/ops/common/__pycache__/chunk_delta_h.cpython-312.pyc +0 -0
  6. fla/ops/common/chunk_delta_h.py +399 -0
  7. fla/ops/common/chunk_h_parallel.py +650 -0
  8. fla/ops/common/fused_recurrent.py +575 -0
  9. fla/ops/common/utils.py +69 -0
  10. fla/ops/delta_rule/README.md +90 -0
  11. fla/ops/delta_rule/__init__.py +11 -0
  12. fla/ops/delta_rule/parallel.py +394 -0
  13. fla/ops/gated_delta_rule/__init__.py +7 -0
  14. fla/ops/gated_delta_rule/__pycache__/chunk.cpython-312.pyc +0 -0
  15. fla/ops/gated_delta_rule/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  16. fla/ops/generalized_delta_rule/__init__.py +9 -0
  17. fla/ops/generalized_delta_rule/__pycache__/__init__.cpython-312.pyc +0 -0
  18. fla/ops/generalized_delta_rule/dplr/__init__.py +7 -0
  19. fla/ops/generalized_delta_rule/dplr/__pycache__/__init__.cpython-312.pyc +0 -0
  20. fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_bwd.cpython-312.pyc +0 -0
  21. fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_bwd.cpython-312.pyc +0 -0
  22. fla/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_fwd.cpython-312.pyc +0 -0
  23. fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py +446 -0
  24. fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py +324 -0
  25. fla/ops/generalized_delta_rule/dplr/chunk_h_bwd.py +196 -0
  26. fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py +197 -0
  27. fla/ops/generalized_delta_rule/dplr/fused_recurrent.py +292 -0
  28. fla/ops/generalized_delta_rule/dplr/naive.py +96 -0
  29. fla/ops/generalized_delta_rule/iplr/__pycache__/__init__.cpython-312.pyc +0 -0
  30. fla/ops/generalized_delta_rule/iplr/__pycache__/wy_fast.cpython-312.pyc +0 -0
  31. fla/ops/generalized_delta_rule/iplr/wy_fast.py +338 -0
  32. fla/ops/gla/__pycache__/__init__.cpython-312.pyc +0 -0
  33. fla/ops/gla/__pycache__/chunk.cpython-312.pyc +0 -0
  34. fla/ops/gla/__pycache__/fused_chunk.cpython-312.pyc +0 -0
  35. fla/ops/gla/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  36. fla/ops/gsa/__init__.py +9 -0
  37. fla/ops/gsa/naive.py +68 -0
  38. fla/ops/hgrn/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  39. fla/ops/hgrn/chunk.py +282 -0
  40. fla/ops/hgrn/fused_recurrent.py +308 -0
  41. fla/ops/lightning_attn/__pycache__/fused_recurrent.cpython-312.pyc +0 -0
  42. fla/ops/linear_attn/__pycache__/__init__.cpython-312.pyc +0 -0
  43. fla/ops/linear_attn/__pycache__/utils.cpython-312.pyc +0 -0
  44. fla/ops/linear_attn/fused_recurrent.py +251 -0
  45. fla/ops/nsa/__pycache__/__init__.cpython-312.pyc +0 -0
  46. fla/ops/nsa/__pycache__/naive.cpython-312.pyc +0 -0
  47. fla/ops/nsa/__pycache__/parallel.cpython-312.pyc +0 -0
  48. fla/ops/rebased/__pycache__/__init__.cpython-312.pyc +0 -0
  49. fla/ops/rebased/parallel.py +466 -0
  50. fla/ops/retention/__pycache__/__init__.cpython-312.pyc +0 -0
fla/ops/abc/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (212 Bytes). View file
 
fla/ops/attn/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (220 Bytes). View file
 
fla/ops/attn/parallel.py ADDED
@@ -0,0 +1,629 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from einops import rearrange, reduce
10
+
11
+ from fla.ops.common.utils import prepare_chunk_indices
12
+ from fla.ops.utils.op import exp, log
13
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, contiguous
14
+
15
+
16
+ @triton.heuristics({
17
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
18
+ })
19
+ @triton.autotune(
20
+ configs=[
21
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
22
+ for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
23
+ for num_stages in [2, 3, 4, 5]
24
+ ],
25
+ key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
26
+ )
27
+ @triton.jit
28
+ def parallel_attn_fwd_kernel(
29
+ q,
30
+ k,
31
+ v,
32
+ o,
33
+ lse,
34
+ scale,
35
+ offsets,
36
+ indices,
37
+ T,
38
+ B: tl.constexpr,
39
+ H: tl.constexpr,
40
+ HQ: tl.constexpr,
41
+ G: tl.constexpr,
42
+ K: tl.constexpr,
43
+ V: tl.constexpr,
44
+ BT: tl.constexpr,
45
+ BS: tl.constexpr,
46
+ BK: tl.constexpr,
47
+ BV: tl.constexpr,
48
+ USE_OFFSETS: tl.constexpr
49
+ ):
50
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
51
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
52
+ i_h = i_hq // G
53
+
54
+ if USE_OFFSETS:
55
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
56
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
57
+ T = eos - bos
58
+ else:
59
+ i_n = i_b
60
+ bos, eos = i_n * T, i_n * T + T
61
+
62
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
63
+ p_o = tl.make_block_ptr(o + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
64
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
65
+
66
+ # the Q block is kept in the shared memory throughout the whole kernel
67
+ # [BT, BK]
68
+ b_q = tl.load(p_q, boundary_check=(0, 1))
69
+ b_q = (b_q * scale).to(b_q.dtype)
70
+ # [BT, BV]
71
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
72
+
73
+ b_m = tl.full([BT], float('-inf'), dtype=tl.float32)
74
+ b_acc = tl.zeros([BT], dtype=tl.float32)
75
+ for i_s in range(0, i_t * BT, BS):
76
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
77
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
78
+ # [BK, BS]
79
+ b_k = tl.load(p_k, boundary_check=(0, 1))
80
+ # [BS, BV]
81
+ b_v = tl.load(p_v, boundary_check=(0, 1))
82
+ # [BT, BS]
83
+ b_s = tl.dot(b_q, b_k)
84
+
85
+ # [BT, BS]
86
+ b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
87
+ b_r = exp(b_mp - b_m)
88
+ # [BT, BS]
89
+ b_p = exp(b_s - b_m[:, None])
90
+ # [BT]
91
+ b_acc = b_acc * b_r + tl.sum(b_p, 1)
92
+ # [BT, BV]
93
+ b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)
94
+
95
+ b_mp = b_m
96
+
97
+ # [BT]
98
+ o_q = i_t * BT + tl.arange(0, BT)
99
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
100
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
101
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
102
+
103
+ # [BS]
104
+ o_k = i_s + tl.arange(0, BS)
105
+ # [BK, BS]
106
+ b_k = tl.load(p_k, boundary_check=(0, 1))
107
+ # [BS, BV]
108
+ b_v = tl.load(p_v, boundary_check=(0, 1))
109
+ # [BT, BS]
110
+ b_s = tl.dot(b_q, b_k)
111
+ b_s = tl.where(o_q[:, None] >= o_k[None, :], b_s, float('-inf'))
112
+
113
+ # [BT]
114
+ b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
115
+ b_r = exp(b_mp - b_m)
116
+ # [BT, BS]
117
+ b_p = exp(b_s - b_m[:, None])
118
+ # [BT]
119
+ b_acc = b_acc * b_r + tl.sum(b_p, 1)
120
+ # [BT, BV]
121
+ b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)
122
+
123
+ b_mp = b_m
124
+ b_o = b_o / b_acc[:, None]
125
+ b_m += log(b_acc)
126
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
127
+ tl.store(p_lse, b_m.to(p_lse.dtype.element_ty), boundary_check=(0,))
128
+
129
+
130
+ @triton.jit
131
+ def parallel_attn_bwd_kernel_preprocess(
132
+ o,
133
+ do,
134
+ delta,
135
+ B: tl.constexpr,
136
+ V: tl.constexpr
137
+ ):
138
+ i_n = tl.program_id(0)
139
+ o_d = tl.arange(0, B)
140
+ m_d = o_d < V
141
+
142
+ b_o = tl.load(o + i_n * V + o_d, mask=m_d, other=0)
143
+ b_do = tl.load(do + i_n * V + o_d, mask=m_d, other=0).to(tl.float32)
144
+ b_delta = tl.sum(b_o * b_do)
145
+
146
+ tl.store(delta + i_n, b_delta.to(delta.dtype.element_ty))
147
+
148
+
149
+ @triton.heuristics({
150
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
151
+ })
152
+ @triton.autotune(
153
+ configs=[
154
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
155
+ for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
156
+ for num_stages in [2, 3, 4, 5]
157
+ ],
158
+ key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
159
+ )
160
+ @triton.jit(do_not_specialize=['T'])
161
+ def parallel_attn_bwd_kernel_dq(
162
+ q,
163
+ k,
164
+ v,
165
+ lse,
166
+ delta,
167
+ do,
168
+ dq,
169
+ scale,
170
+ offsets,
171
+ indices,
172
+ T,
173
+ B: tl.constexpr,
174
+ H: tl.constexpr,
175
+ HQ: tl.constexpr,
176
+ G: tl.constexpr,
177
+ K: tl.constexpr,
178
+ V: tl.constexpr,
179
+ BT: tl.constexpr,
180
+ BS: tl.constexpr,
181
+ BK: tl.constexpr,
182
+ BV: tl.constexpr,
183
+ USE_OFFSETS: tl.constexpr
184
+ ):
185
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
186
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
187
+ i_h = i_hq // G
188
+
189
+ if USE_OFFSETS:
190
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
191
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
192
+ T = eos - bos
193
+ else:
194
+ i_n = i_b
195
+ bos, eos = i_n * T, i_n * T + T
196
+
197
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
198
+ p_dq = tl.make_block_ptr(dq + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
199
+ p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
200
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
201
+ p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,))
202
+
203
+ # [BT, BK]
204
+ b_q = tl.load(p_q, boundary_check=(0, 1))
205
+ b_q = (b_q * scale).to(b_q.dtype)
206
+ # [BT, BV]
207
+ b_do = tl.load(p_do, boundary_check=(0, 1))
208
+ # [BT]
209
+ b_lse = tl.load(p_lse, boundary_check=(0,))
210
+ b_delta = tl.load(p_delta, boundary_check=(0,))
211
+
212
+ # [BT, BK]
213
+ b_dq = tl.zeros([BT, BK], dtype=tl.float32)
214
+ for i_s in range(0, i_t * BT, BS):
215
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
216
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1))
217
+ # [BK, BS]
218
+ b_k = tl.load(p_k, boundary_check=(0, 1))
219
+ # [BV, BS]
220
+ b_v = tl.load(p_v, boundary_check=(0, 1))
221
+
222
+ # [BT, BS]
223
+ b_s = tl.dot(b_q, b_k)
224
+ b_p = exp(b_s - b_lse[:, None])
225
+
226
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
227
+ b_dp = tl.dot(b_do, b_v)
228
+ b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None])
229
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
230
+ b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))
231
+
232
+ # [BT]
233
+ o_q = i_t * BT + tl.arange(0, BT)
234
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
235
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
236
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1))
237
+ # [BS]
238
+ o_k = i_s + tl.arange(0, BS)
239
+ # [BK, BS]
240
+ b_k = tl.load(p_k, boundary_check=(0, 1))
241
+ # [BV, BS]
242
+ b_v = tl.load(p_v, boundary_check=(0, 1))
243
+
244
+ # [BT, BS]
245
+ b_s = tl.dot(b_q, b_k)
246
+ b_p = exp(b_s - b_lse[:, None])
247
+ b_p = tl.where(o_q[:, None] >= o_k[None, :], b_p, 0)
248
+
249
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
250
+ b_dp = tl.dot(b_do, b_v)
251
+ b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None])
252
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
253
+ b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))
254
+
255
+ b_dq *= scale
256
+
257
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
258
+
259
+
260
+ @triton.heuristics({
261
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
262
+ })
263
+ @triton.autotune(
264
+ configs=[
265
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
266
+ for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else [])
267
+ for num_stages in [2, 3, 4, 5]
268
+ ],
269
+ key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'],
270
+ )
271
+ @triton.jit(do_not_specialize=['T'])
272
+ def parallel_attn_bwd_kernel_dkv(
273
+ q,
274
+ k,
275
+ v,
276
+ lse,
277
+ delta,
278
+ do,
279
+ dk,
280
+ dv,
281
+ offsets,
282
+ indices,
283
+ scale,
284
+ T,
285
+ B: tl.constexpr,
286
+ H: tl.constexpr,
287
+ HQ: tl.constexpr,
288
+ G: tl.constexpr,
289
+ K: tl.constexpr,
290
+ V: tl.constexpr,
291
+ BT: tl.constexpr,
292
+ BS: tl.constexpr,
293
+ BK: tl.constexpr,
294
+ BV: tl.constexpr,
295
+ USE_OFFSETS: tl.constexpr
296
+ ):
297
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
298
+ i_b, i_hq = i_bh // HQ, i_bh % HQ
299
+ i_h = i_hq // G
300
+
301
+ if USE_OFFSETS:
302
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
303
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
304
+ T = eos - bos
305
+ else:
306
+ i_n = i_b
307
+ bos, eos = i_n * T, i_n * T + T
308
+
309
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
310
+ p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
311
+ p_dk = tl.make_block_ptr(dk + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
312
+ p_dv = tl.make_block_ptr(dv + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
313
+
314
+ # [BT, BK]
315
+ b_k = tl.load(p_k, boundary_check=(0, 1))
316
+ b_dk = tl.zeros([BT, BK], dtype=tl.float32)
317
+ # [BT, BV]
318
+ b_v = tl.load(p_v, boundary_check=(0, 1))
319
+ b_dv = tl.zeros([BT, BV], dtype=tl.float32)
320
+
321
+ o_k = i_t * BT + tl.arange(0, BT)
322
+ for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS):
323
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0))
324
+ p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
325
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
326
+ p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
327
+
328
+ # [BS]
329
+ o_q = i_s + tl.arange(0, BS)
330
+ # [BS, BK]
331
+ b_q = tl.load(p_q, boundary_check=(0, 1))
332
+ b_q = (b_q * scale).to(b_q.dtype)
333
+ # [BS, BV]
334
+ b_do = tl.load(p_do, boundary_check=(0, 1))
335
+ # [BS]
336
+ b_lse = tl.load(p_lse, boundary_check=(0,))
337
+ b_delta = tl.load(p_delta, boundary_check=(0,))
338
+ # [BT, BS]
339
+ b_s = tl.dot(b_k, tl.trans(b_q))
340
+ b_p = exp(b_s - b_lse[None, :])
341
+ b_p = tl.where(o_k[:, None] <= o_q[None, :], b_p, 0)
342
+ # [BT, BS] @ [BS, BV] -> [BT, BV]
343
+ b_dv += tl.dot(b_p.to(b_do.dtype), b_do)
344
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
345
+ b_dp = tl.dot(b_v, tl.trans(b_do))
346
+ # [BT, BS]
347
+ b_ds = b_p * (b_dp - b_delta[None, :])
348
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
349
+ b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)
350
+
351
+ for i_s in range((i_t + 1) * BT, tl.cdiv(T, BS) * BS, BS):
352
+ p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0))
353
+ p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
354
+ p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
355
+ p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,))
356
+
357
+ # [BS]
358
+ o_q = i_s + tl.arange(0, BS)
359
+ # [BS, BK]
360
+ b_q = tl.load(p_q, boundary_check=(0, 1))
361
+ b_q = (b_q * scale).to(b_q.dtype)
362
+ # [BS, BV]
363
+ b_do = tl.load(p_do, boundary_check=(0, 1))
364
+ # [BS]
365
+ b_lse = tl.load(p_lse, boundary_check=(0,))
366
+ b_delta = tl.load(p_delta, boundary_check=(0,))
367
+ # [BT, BS]
368
+ b_s = tl.dot(b_k, tl.trans(b_q))
369
+ b_p = exp(b_s - b_lse[None, :])
370
+ # [BT, BS] @ [BS, BV] -> [BT, BV]
371
+ b_dv += tl.dot(b_p.to(b_do.dtype), b_do)
372
+ # [BT, BV] @ [BV, BS] -> [BT, BS]
373
+ b_dp = tl.dot(b_v, tl.trans(b_do))
374
+ # [BT, BS]
375
+ b_ds = b_p * (b_dp - b_delta[None, :])
376
+ # [BT, BS] @ [BS, BK] -> [BT, BK]
377
+ b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)
378
+
379
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
380
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
381
+
382
+
383
+ def parallel_attn_fwd(
384
+ q: torch.Tensor,
385
+ k: torch.Tensor,
386
+ v: torch.Tensor,
387
+ scale: float,
388
+ chunk_size: int = 128,
389
+ offsets: Optional[torch.LongTensor] = None,
390
+ indices: Optional[torch.LongTensor] = None,
391
+ ):
392
+ B, T, H, K, V = *k.shape, v.shape[-1]
393
+ HQ = q.shape[2]
394
+ G = HQ // H
395
+ BT = chunk_size
396
+ if check_shared_mem('hopper', q.device.index):
397
+ BS = min(64, max(16, triton.next_power_of_2(T)))
398
+ BK = min(256, max(16, triton.next_power_of_2(K)))
399
+ BV = min(256, max(16, triton.next_power_of_2(V)))
400
+ elif check_shared_mem('ampere', q.device.index):
401
+ BS = min(32, max(16, triton.next_power_of_2(T)))
402
+ BK = min(256, max(16, triton.next_power_of_2(K)))
403
+ BV = min(128, max(16, triton.next_power_of_2(V)))
404
+ else:
405
+ BS = min(32, max(16, triton.next_power_of_2(T)))
406
+ BK = min(256, max(16, triton.next_power_of_2(K)))
407
+ BV = min(64, max(16, triton.next_power_of_2(V)))
408
+ NK = triton.cdiv(K, BK)
409
+ NV = triton.cdiv(V, BV)
410
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
411
+ assert NK == 1, "The key dimension can not be larger than 256"
412
+
413
+ o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
414
+ lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)
415
+
416
+ grid = (NV, NT, B * HQ)
417
+ parallel_attn_fwd_kernel[grid](
418
+ q=q,
419
+ k=k,
420
+ v=v,
421
+ o=o,
422
+ lse=lse,
423
+ scale=scale,
424
+ offsets=offsets,
425
+ indices=indices,
426
+ B=B,
427
+ T=T,
428
+ H=H,
429
+ HQ=HQ,
430
+ G=G,
431
+ K=K,
432
+ V=V,
433
+ BT=BT,
434
+ BS=BS,
435
+ BK=BK,
436
+ BV=BV,
437
+ )
438
+ return o, lse
439
+
440
+
441
+ def parallel_attn_bwd_preprocess(
442
+ o: torch.Tensor,
443
+ do: torch.Tensor
444
+ ):
445
+ V = o.shape[-1]
446
+ delta = torch.empty_like(o[..., 0], dtype=torch.float32)
447
+ parallel_attn_bwd_kernel_preprocess[(delta.numel(),)](
448
+ o=o,
449
+ do=do,
450
+ delta=delta,
451
+ B=triton.next_power_of_2(V),
452
+ V=V,
453
+ )
454
+ return delta
455
+
456
+
457
+ def parallel_attn_bwd(
458
+ q: torch.Tensor,
459
+ k: torch.Tensor,
460
+ v: torch.Tensor,
461
+ o: torch.Tensor,
462
+ lse: torch.Tensor,
463
+ do: torch.Tensor,
464
+ scale: float = None,
465
+ chunk_size: int = 128,
466
+ offsets: Optional[torch.LongTensor] = None,
467
+ indices: Optional[torch.LongTensor] = None,
468
+ ):
469
+ B, T, H, K, V = *k.shape, v.shape[-1]
470
+ HQ = q.shape[2]
471
+ G = HQ // H
472
+ BT = chunk_size
473
+ BS = max(16, triton.next_power_of_2(T))
474
+ BS = min(32, BS) if check_shared_mem('ampere') else min(16, BS)
475
+ BK = max(16, triton.next_power_of_2(K))
476
+ BV = max(16, triton.next_power_of_2(V))
477
+ NV = triton.cdiv(V, BV)
478
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
479
+
480
+ delta = parallel_attn_bwd_preprocess(o, do)
481
+
482
+ dq = torch.empty(B, T, HQ, K, dtype=k.dtype if H == HQ else torch.float, device=q.device)
483
+ dk = torch.empty(B, T, HQ, K, dtype=k.dtype if H == HQ else torch.float, device=q.device)
484
+ dv = torch.empty(B, T, HQ, V, dtype=v.dtype if H == HQ else torch.float, device=q.device)
485
+ grid = (NV, NT, B * HQ)
486
+ parallel_attn_bwd_kernel_dq[grid](
487
+ q=q,
488
+ k=k,
489
+ v=v,
490
+ lse=lse,
491
+ delta=delta,
492
+ do=do,
493
+ dq=dq,
494
+ offsets=offsets,
495
+ indices=indices,
496
+ scale=scale,
497
+ T=T,
498
+ B=B,
499
+ H=H,
500
+ HQ=HQ,
501
+ G=G,
502
+ K=K,
503
+ V=V,
504
+ BT=BT,
505
+ BS=BS,
506
+ BK=BK,
507
+ BV=BV
508
+ )
509
+ parallel_attn_bwd_kernel_dkv[grid](
510
+ q=q,
511
+ k=k,
512
+ v=v,
513
+ lse=lse,
514
+ delta=delta,
515
+ do=do,
516
+ dk=dk,
517
+ dv=dv,
518
+ offsets=offsets,
519
+ indices=indices,
520
+ scale=scale,
521
+ T=T,
522
+ B=B,
523
+ H=H,
524
+ HQ=HQ,
525
+ G=G,
526
+ K=K,
527
+ V=V,
528
+ BT=BT,
529
+ BS=BS,
530
+ BK=BK,
531
+ BV=BV
532
+ )
533
+ dk = reduce(dk, 'b t (h g) k -> b t h k', g=G, reduction='sum')
534
+ dv = reduce(dv, 'b t (h g) v -> b t h v', g=G, reduction='sum')
535
+ return dq, dk, dv
536
+
537
+
538
+ @torch.compile
539
+ class ParallelAttentionFunction(torch.autograd.Function):
540
+
541
+ @staticmethod
542
+ @contiguous
543
+ @autocast_custom_fwd
544
+ def forward(ctx, q, k, v, scale, offsets):
545
+ ctx.dtype = q.dtype
546
+
547
+ chunk_size = min(128, max(16, triton.next_power_of_2(q.shape[1])))
548
+ # 2-d indices denoting the offsets of chunks in each sequence
549
+ # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
550
+ # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be
551
+ # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
552
+ indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None
553
+
554
+ o, lse = parallel_attn_fwd(
555
+ q=q,
556
+ k=k,
557
+ v=v,
558
+ scale=scale,
559
+ chunk_size=chunk_size,
560
+ offsets=offsets,
561
+ indices=indices
562
+ )
563
+ ctx.save_for_backward(q, k, v, o, lse)
564
+ ctx.chunk_size = chunk_size
565
+ ctx.offsets = offsets
566
+ ctx.indices = indices
567
+ ctx.scale = scale
568
+ return o.to(q.dtype)
569
+
570
+ @staticmethod
571
+ @contiguous
572
+ @autocast_custom_bwd
573
+ def backward(ctx, do):
574
+ q, k, v, o, lse = ctx.saved_tensors
575
+ dq, dk, dv = parallel_attn_bwd(
576
+ q=q,
577
+ k=k,
578
+ v=v,
579
+ o=o,
580
+ lse=lse,
581
+ do=do,
582
+ scale=ctx.scale,
583
+ chunk_size=ctx.chunk_size,
584
+ offsets=ctx.offsets,
585
+ indices=ctx.indices
586
+ )
587
+ return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None
588
+
589
+
590
+ def parallel_attn(
591
+ q: torch.Tensor,
592
+ k: torch.Tensor,
593
+ v: torch.Tensor,
594
+ scale: Optional[float] = None,
595
+ cu_seqlens: Optional[torch.LongTensor] = None,
596
+ head_first: bool = False
597
+ ) -> torch.Tensor:
598
+ r"""
599
+ Args:
600
+ q (torch.Tensor):
601
+ queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
602
+ k (torch.Tensor):
603
+ keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
604
+ GQA will be applied if HQ is divisible by H.
605
+ v (torch.Tensor):
606
+ values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
607
+ scale (Optional[int]):
608
+ Scale factor for attention scores.
609
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
610
+ cu_seqlens (torch.LongTensor):
611
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
612
+ consistent with the FlashAttention API.
613
+ head_first (Optional[bool]):
614
+ Whether the inputs are in the head-first format. Default: `False`.
615
+
616
+ Returns:
617
+ o (torch.Tensor):
618
+ Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
619
+ """
620
+ if scale is None:
621
+ scale = k.shape[-1] ** -0.5
622
+ if cu_seqlens is not None:
623
+ assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
624
+ if head_first:
625
+ q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v))
626
+ o = ParallelAttentionFunction.apply(q, k, v, scale, cu_seqlens)
627
+ if head_first:
628
+ o = rearrange(o, 'b t h d -> b h t d')
629
+ return o
fla/ops/based/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (286 Bytes). View file
 
fla/ops/common/__pycache__/chunk_delta_h.cpython-312.pyc ADDED
Binary file (23.9 kB). View file
 
fla/ops/common/chunk_delta_h.py ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.utils import prepare_chunk_offsets
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import check_shared_mem, is_nvidia_hopper, use_cuda_graph
13
+
14
+ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
15
+
16
+
17
+ @triton.heuristics({
18
+ 'USE_G': lambda args: args['g'] is not None,
19
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
20
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
21
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
22
+ })
23
+ @triton.autotune(
24
+ configs=[
25
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
26
+ for num_warps in NUM_WARPS
27
+ for num_stages in [2, 3, 4]
28
+ ],
29
+ key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G'],
30
+ use_cuda_graph=use_cuda_graph,
31
+ )
32
+ @triton.jit(do_not_specialize=['T'])
33
+ def chunk_gated_delta_rule_fwd_kernel_h(
34
+ k,
35
+ v,
36
+ d,
37
+ v_new,
38
+ g,
39
+ h,
40
+ h0,
41
+ ht,
42
+ offsets,
43
+ chunk_offsets,
44
+ T,
45
+ H: tl.constexpr,
46
+ K: tl.constexpr,
47
+ V: tl.constexpr,
48
+ BT: tl.constexpr,
49
+ BC: tl.constexpr,
50
+ BK: tl.constexpr,
51
+ BV: tl.constexpr,
52
+ NT: tl.constexpr,
53
+ USE_G: tl.constexpr,
54
+ USE_INITIAL_STATE: tl.constexpr,
55
+ STORE_FINAL_STATE: tl.constexpr,
56
+ USE_OFFSETS: tl.constexpr,
57
+ HEAD_FIRST: tl.constexpr,
58
+ ):
59
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
60
+ i_n, i_h = i_nh // H, i_nh % H
61
+ if USE_OFFSETS:
62
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
63
+ T = eos - bos
64
+ NT = tl.cdiv(T, BT)
65
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
66
+ else:
67
+ bos, eos = i_n * T, i_n * T + T
68
+ NT = tl.cdiv(T, BT)
69
+ boh = i_n * NT
70
+
71
+ # [BK, BV]
72
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
73
+ if USE_INITIAL_STATE:
74
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
75
+ b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
76
+
77
+ for i_t in range(NT):
78
+ if HEAD_FIRST:
79
+ p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
80
+ else:
81
+ p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
82
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
83
+ b_hc = tl.zeros([BK, BV], dtype=tl.float32)
84
+ if USE_G:
85
+ last_idx = min((i_t + 1) * BT, T) - 1
86
+ if HEAD_FIRST:
87
+ b_g_last = tl.load(g + i_nh * T + last_idx)
88
+ else:
89
+ b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
90
+ else:
91
+ b_g_last = None
92
+ last_idx = None
93
+ # since we need to make all DK in the SRAM. we face serve SRAM memory burden. By subchunking we allievate such burden
94
+ for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)):
95
+ if HEAD_FIRST:
96
+ p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
97
+ p_d = tl.make_block_ptr(d + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
98
+ p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
99
+ p_v_new = tl.make_block_ptr(v_new+i_nh*T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
100
+ p_g = tl.make_block_ptr(g + i_nh * T, (T,), (1,), (i_t * BT + i_c * BC,), (BC,), (0,)) if USE_G else None
101
+ else:
102
+ p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
103
+ p_d = tl.make_block_ptr(d+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
104
+ p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
105
+ p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT+i_c*BC, i_v * BV), (BC, BV), (1, 0))
106
+ p_g = tl.make_block_ptr(g+bos*H+i_h, (T,), (H,), (i_t*BT+i_c*BC, ), (BC,), (0,)) if USE_G else None
107
+ b_g = tl.load(p_g, boundary_check=(0, )) if USE_G else None
108
+ # [BK, BC]
109
+ b_k = tl.load(p_k, boundary_check=(0, 1))
110
+ b_k = (b_k * exp(b_g_last - b_g)[None, :]).to(b_k.dtype) if USE_G else b_k
111
+ # [BC, BK]
112
+ b_d = tl.load(p_d, boundary_check=(0, 1))
113
+ b_d = (b_d * exp(b_g)[:, None]).to(b_d.dtype) if USE_G else b_d
114
+ # [BC, BV]
115
+ b_v = tl.load(p_v, boundary_check=(0, 1))
116
+ b_v2 = b_v - tl.dot(b_d, b_h.to(b_d.dtype))
117
+ # [BK, BV]
118
+ tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
119
+ b_hc += tl.dot(b_k, b_v2.to(b_k.dtype), allow_tf32=False)
120
+ b_h *= exp(b_g_last) if USE_G else 1
121
+ b_h += b_hc
122
+
123
+ if STORE_FINAL_STATE:
124
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
125
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
126
+
127
+
128
+ @triton.heuristics({
129
+ 'USE_G': lambda args: args['g'] is not None,
130
+ 'USE_INITIAL_STATE': lambda args: args['dh0'] is not None,
131
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
132
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
133
+ })
134
+ @triton.autotune(
135
+ configs=[
136
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
137
+ for num_warps in NUM_WARPS
138
+ for num_stages in [2, 3, 4]
139
+ ],
140
+ key=['BT', 'BK', 'BV', 'USE_G'],
141
+ use_cuda_graph=use_cuda_graph,
142
+ )
143
+ @triton.jit(do_not_specialize=['T'])
144
+ def chunk_gated_delta_rule_bwd_kernel_dhu(
145
+ q,
146
+ k,
147
+ d,
148
+ g,
149
+ dht,
150
+ dh0,
151
+ do,
152
+ dh,
153
+ dv,
154
+ dv2,
155
+ offsets,
156
+ chunk_offsets,
157
+ scale,
158
+ T,
159
+ H: tl.constexpr,
160
+ K: tl.constexpr,
161
+ V: tl.constexpr,
162
+ BT: tl.constexpr,
163
+ BC: tl.constexpr,
164
+ BK: tl.constexpr,
165
+ BV: tl.constexpr,
166
+ USE_G: tl.constexpr,
167
+ USE_INITIAL_STATE: tl.constexpr,
168
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
169
+ USE_OFFSETS: tl.constexpr,
170
+ HEAD_FIRST: tl.constexpr
171
+ ):
172
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
173
+ i_n, i_h = i_nh // H, i_nh % H
174
+ if USE_OFFSETS:
175
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
176
+ T = eos - bos
177
+ NT = tl.cdiv(T, BT)
178
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
179
+ else:
180
+ bos, eos = i_n * T, i_n * T + T
181
+ NT = tl.cdiv(T, BT)
182
+ boh = i_n * NT
183
+
184
+ # [BK, BV]
185
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
186
+ if USE_FINAL_STATE_GRADIENT:
187
+ p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
188
+ b_dh += tl.load(p_dht, boundary_check=(0, 1))
189
+
190
+ for i_t in range(NT - 1, -1, -1):
191
+ if HEAD_FIRST:
192
+ p_dh = tl.make_block_ptr(dh + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
193
+ else:
194
+ p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
195
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
196
+ b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32)
197
+ if USE_G:
198
+ last_idx = min((i_t + 1) * BT, T) - 1
199
+ if HEAD_FIRST:
200
+ bg_last = tl.load(g + i_nh * T + last_idx)
201
+ else:
202
+ bg_last = tl.load(g + (bos + last_idx) * H + i_h)
203
+ else:
204
+ bg_last = None
205
+ last_idx = None
206
+ for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1):
207
+ if HEAD_FIRST:
208
+ p_q = tl.make_block_ptr(q + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
209
+ p_k = tl.make_block_ptr(k + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
210
+ p_d = tl.make_block_ptr(d + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
211
+ p_dv = tl.make_block_ptr(dv + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
212
+ p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
213
+ p_g = tl.make_block_ptr(g + i_nh * T, (T,), (1,), (i_t * BT + i_c * BC,), (BC,), (0,)) if USE_G else None
214
+ p_dv2 = tl.make_block_ptr(dv2 + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
215
+ else:
216
+ p_q = tl.make_block_ptr(q+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
217
+ p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
218
+ p_d = tl.make_block_ptr(d+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
219
+ p_dv = tl.make_block_ptr(dv+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
220
+ p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
221
+ p_g = tl.make_block_ptr(g+bos*H+i_h, (T,), (H,), (i_t*BT + i_c * BC,), (BC,), (0,)) if USE_G else None
222
+ p_dv2 = tl.make_block_ptr(dv2+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
223
+ b_g = tl.load(p_g, boundary_check=(0,)) if USE_G else None
224
+ # [BK, BT]
225
+ b_q = tl.load(p_q, boundary_check=(0, 1))
226
+ b_q = (b_q * scale * exp(b_g)[None, :]).to(b_q.dtype) if USE_G else (b_q * scale).to(b_q.dtype)
227
+ # [BT, BK]
228
+ b_k = tl.load(p_k, boundary_check=(0, 1))
229
+ b_d = tl.load(p_d, boundary_check=(0, 1))
230
+ b_k = (b_k * exp(bg_last - b_g)[:, None]).to(b_k.dtype) if USE_G else b_k
231
+ b_d = (b_d * exp(b_g)[None, :]).to(b_d.dtype) if USE_G else b_d
232
+ # [BT, V]
233
+ b_do = tl.load(p_do, boundary_check=(0, 1))
234
+ b_dv = tl.load(p_dv, boundary_check=(0, 1))
235
+ b_dv2 = b_dv + tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False)
236
+ tl.store(p_dv2, b_dv2.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
237
+ # [BK, BV]
238
+ b_dh_tmp += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)
239
+ b_dh_tmp -= tl.dot(b_d, b_dv2.to(b_q.dtype), allow_tf32=False)
240
+ b_dh *= exp(bg_last) if USE_G else 1
241
+ b_dh += b_dh_tmp
242
+
243
+ if USE_INITIAL_STATE:
244
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
245
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
246
+
247
+
248
+ def chunk_gated_delta_rule_fwd_h(
249
+ k: torch.Tensor,
250
+ w: torch.Tensor,
251
+ u: torch.Tensor,
252
+ g: Optional[torch.Tensor] = None,
253
+ initial_state: Optional[torch.Tensor] = None,
254
+ output_final_state: bool = False,
255
+ offsets: Optional[torch.LongTensor] = None,
256
+ indices: Optional[torch.LongTensor] = None,
257
+ head_first: bool = True,
258
+ chunk_size: int = 64
259
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
260
+ if head_first:
261
+ B, H, T, K, V = *k.shape, u.shape[-1]
262
+ else:
263
+ B, T, H, K, V = *k.shape, u.shape[-1]
264
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
265
+ # N: the actual number of sequences in the batch with either equal or variable lengths
266
+ if offsets is None:
267
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
268
+ else:
269
+ N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)
270
+ BK = triton.next_power_of_2(K)
271
+ assert BK <= 256, "current kernel does not support head dimension larger than 256."
272
+ # H100 can have larger block size
273
+ if check_shared_mem('hopper', k.device.index):
274
+ BV = 64
275
+ BC = 64 if K <= 128 else 32
276
+ # A100
277
+ elif check_shared_mem('ampere', k.device.index):
278
+ BV = 32
279
+ BC = 64
280
+ else:
281
+ BV = 32
282
+ BC = 32 if K <= 128 else 16
283
+ BC = min(BT, BC)
284
+ NK = triton.cdiv(K, BK)
285
+ NV = triton.cdiv(V, BV)
286
+ assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
287
+
288
+ if head_first:
289
+ h = k.new_empty(B, H, NT, K, V)
290
+ else:
291
+ h = k.new_empty(B, NT, H, K, V)
292
+ final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
293
+
294
+ v_new = torch.empty_like(u)
295
+ grid = (NK, NV, N * H)
296
+
297
+ chunk_gated_delta_rule_fwd_kernel_h[grid](
298
+ k=k,
299
+ v=u,
300
+ d=w,
301
+ v_new=v_new,
302
+ g=g,
303
+ h=h,
304
+ h0=initial_state,
305
+ ht=final_state,
306
+ offsets=offsets,
307
+ chunk_offsets=chunk_offsets,
308
+ T=T,
309
+ H=H,
310
+ K=K,
311
+ V=V,
312
+ BT=BT,
313
+ BC=BC,
314
+ BK=BK,
315
+ BV=BV,
316
+ NT=NT,
317
+ HEAD_FIRST=head_first
318
+ )
319
+ return h, v_new, final_state
320
+
321
+
322
+ def chunk_gated_delta_rule_bwd_dhu(
323
+ q: torch.Tensor,
324
+ k: torch.Tensor,
325
+ w: torch.Tensor,
326
+ g: torch.Tensor,
327
+ h0: torch.Tensor,
328
+ dht: Optional[torch.Tensor],
329
+ do: torch.Tensor,
330
+ dv: torch.Tensor,
331
+ scale: float,
332
+ offsets: Optional[torch.LongTensor] = None,
333
+ indices: Optional[torch.LongTensor] = None,
334
+ head_first: bool = True,
335
+ chunk_size: int = 64
336
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
337
+ if head_first:
338
+ B, H, T, K, V = *q.shape, do.shape[-1]
339
+ else:
340
+ B, T, H, K, V = *q.shape, do.shape[-1]
341
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
342
+ # N: the actual number of sequences in the batch with either equal or variable lengths
343
+ if offsets is None:
344
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
345
+ else:
346
+ N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)
347
+
348
+ BK = triton.next_power_of_2(K)
349
+ assert BK <= 256, "current kernel does not support head dimension being larger than 256."
350
+
351
+ # H100
352
+ if check_shared_mem('hopper', q.device.index):
353
+ BV = 64
354
+ BC = 64 if K <= 128 else 32
355
+ # A100
356
+ elif check_shared_mem('ampere', q.device.index):
357
+ BV = 32
358
+ BC = 64 if K <= 128 else 32
359
+ else:
360
+ BV = 32 if K <= 128 else 16
361
+ BC = 16
362
+
363
+ BC = min(BT, BC)
364
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
365
+ assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
366
+
367
+ if head_first:
368
+ dh = q.new_empty(B, H, NT, K, V)
369
+ else:
370
+ dh = q.new_empty(B, NT, H, K, V)
371
+ dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None
372
+ dv2 = torch.empty_like(dv)
373
+
374
+ grid = (NK, NV, N * H)
375
+ chunk_gated_delta_rule_bwd_kernel_dhu[grid](
376
+ q=q,
377
+ k=k,
378
+ d=w,
379
+ g=g,
380
+ dht=dht,
381
+ dh0=dh0,
382
+ do=do,
383
+ dh=dh,
384
+ dv=dv,
385
+ dv2=dv2,
386
+ offsets=offsets,
387
+ chunk_offsets=chunk_offsets,
388
+ scale=scale,
389
+ T=T,
390
+ H=H,
391
+ K=K,
392
+ V=V,
393
+ BT=BT,
394
+ BC=BC,
395
+ BK=BK,
396
+ BV=BV,
397
+ HEAD_FIRST=head_first
398
+ )
399
+ return dh, dh0, dv2
fla/ops/common/chunk_h_parallel.py ADDED
@@ -0,0 +1,650 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ """
5
+ Fully parallelized state passing.
6
+ """
7
+
8
+ from typing import Optional, Tuple
9
+
10
+ import torch
11
+ import triton
12
+ import triton.language as tl
13
+
14
+ from fla.ops.utils.op import exp
15
+
16
+
17
+ @triton.heuristics({
18
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
19
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
20
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
21
+ })
22
+ @triton.autotune(
23
+ configs=[
24
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
25
+ for BK in [32, 64, 128]
26
+ for BV in [32, 64, 128]
27
+ for num_warps in [2, 4, 8]
28
+ for num_stages in [2, 3, 4]
29
+ ],
30
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
31
+ )
32
+ @triton.jit(do_not_specialize=['T'])
33
+ def chunk_fwd_kernel_h_parallel(
34
+ k,
35
+ v,
36
+ h,
37
+ g,
38
+ gk,
39
+ gv,
40
+ h0,
41
+ ht,
42
+ offsets,
43
+ indices,
44
+ T,
45
+ H: tl.constexpr,
46
+ K: tl.constexpr,
47
+ V: tl.constexpr,
48
+ BT: tl.constexpr,
49
+ BK: tl.constexpr,
50
+ BV: tl.constexpr,
51
+ USE_G: tl.constexpr,
52
+ USE_GK: tl.constexpr,
53
+ USE_GV: tl.constexpr,
54
+ USE_INITIAL_STATE: tl.constexpr,
55
+ STORE_FINAL_STATE: tl.constexpr,
56
+ USE_OFFSETS: tl.constexpr,
57
+ HEAD_FIRST: tl.constexpr
58
+ ):
59
+ i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
60
+
61
+ NV = tl.cdiv(V, BV)
62
+ # i_b: batch index
63
+ # i_h: head index
64
+ # i_n: sequence index
65
+ # i_t: chunk index within current sequence
66
+ # i_tg: (global) chunk index across all sequences
67
+ i_k, i_v = i_kv // NV, i_kv % NV
68
+ i_b, i_h = i_bh // H, i_bh % H
69
+ if USE_OFFSETS:
70
+ i_tg = i_t
71
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
72
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
73
+ T = eos - bos
74
+ NT = tl.cdiv(T, BT)
75
+ else:
76
+ bos, eos = i_b * T, i_b * T + T
77
+ NT = tl.cdiv(T, BT)
78
+ i_n, i_tg = i_b, i_b * NT + i_t
79
+ i_nh = i_n * H + i_h
80
+
81
+ if HEAD_FIRST:
82
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
83
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
84
+ p_h = tl.make_block_ptr(h + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
85
+ else:
86
+ p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
87
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
88
+ p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
89
+
90
+ if i_t == 0:
91
+ if USE_INITIAL_STATE:
92
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
93
+ b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
94
+ else:
95
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
96
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
97
+
98
+ # [BK, BT]
99
+ b_k = tl.load(p_k, boundary_check=(0, 1))
100
+ # [BT, BV]
101
+ b_v = tl.load(p_v, boundary_check=(0, 1))
102
+
103
+ last_idx = min(i_t * BT + BT, T) - 1
104
+ # scalar decay
105
+ if USE_G:
106
+ if HEAD_FIRST:
107
+ b_g_last = tl.load(g + i_bh * T + last_idx)
108
+ p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT)
109
+ p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
110
+ else:
111
+ b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
112
+ p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h
113
+ b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
114
+ b_v = (b_v * exp(b_g_last - b_g)[:, None]).to(b_v.dtype)
115
+
116
+ # vector decay, h = Diag(gk) @ h
117
+ if USE_GK:
118
+ if HEAD_FIRST:
119
+ p_gk = tl.make_block_ptr(gk + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
120
+ p_gk_last = gk + i_bh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
121
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
122
+ else:
123
+ p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
124
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
125
+
126
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
127
+
128
+ b_gk = tl.load(p_gk, boundary_check=(0, 1))
129
+ b_k = (b_k * exp(b_gk_last[:, None] - b_gk)).to(b_k.dtype)
130
+
131
+ # vector decay, h = h @ Diag(gv)
132
+ if USE_GV:
133
+ if HEAD_FIRST:
134
+ p_gv = tl.make_block_ptr(gv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
135
+ p_gv_last = gv + i_bh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
136
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
137
+ else:
138
+ p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
139
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
140
+
141
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
142
+
143
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
144
+ b_v = (b_v * exp(b_gv_last[None, :] - b_gv)).to(b_v.dtype)
145
+
146
+ b_h = tl.dot(b_k, b_v)
147
+ if i_t < NT - 1:
148
+ if HEAD_FIRST:
149
+ p_h = tl.make_block_ptr(h + (i_bh * NT + i_t + 1) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
150
+ else:
151
+ p_h = tl.make_block_ptr(h + ((i_tg + 1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
152
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
153
+ elif STORE_FINAL_STATE:
154
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
155
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
156
+
157
+
158
+ @triton.heuristics({
159
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
160
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
161
+ })
162
+ @triton.autotune(
163
+ configs=[
164
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
165
+ for BK in [32, 64, 128]
166
+ for BV in [32, 64, 128]
167
+ for num_warps in [2, 4, 8, 16]
168
+ for num_stages in [2, 3]
169
+ ],
170
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
171
+ )
172
+ @triton.jit(do_not_specialize=['T'])
173
+ def chunk_fwd_kernel_h_reduction(
174
+ h,
175
+ g,
176
+ gk,
177
+ gv,
178
+ kvt,
179
+ ht,
180
+ offsets,
181
+ chunk_offsets,
182
+ T,
183
+ H: tl.constexpr,
184
+ K: tl.constexpr,
185
+ V: tl.constexpr,
186
+ BT: tl.constexpr,
187
+ BK: tl.constexpr,
188
+ BV: tl.constexpr,
189
+ USE_G: tl.constexpr,
190
+ USE_GK: tl.constexpr,
191
+ USE_GV: tl.constexpr,
192
+ STORE_FINAL_STATE: tl.constexpr,
193
+ USE_OFFSETS: tl.constexpr,
194
+ HEAD_FIRST: tl.constexpr
195
+ ):
196
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
197
+ i_n, i_h = i_nh // H, i_nh % H
198
+ if USE_OFFSETS:
199
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
200
+ T = eos - bos
201
+ NT = tl.cdiv(T, BT)
202
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
203
+ else:
204
+ bos, eos = i_n * T, i_n * T + T
205
+ NT = tl.cdiv(T, BT)
206
+ boh = i_n * NT
207
+
208
+ # [BK, BV]
209
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
210
+ for i_t in range(NT):
211
+ if HEAD_FIRST:
212
+ p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
213
+ else:
214
+ p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
215
+ b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)
216
+ if i_t > 0:
217
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
218
+
219
+ last_idx = min(i_t * BT + BT, T) - 1
220
+ # scalar decay
221
+ if USE_G:
222
+ if HEAD_FIRST:
223
+ b_g_last = tl.load(g + i_nh * T + last_idx)
224
+ else:
225
+ b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
226
+ b_h *= exp(b_g_last)
227
+
228
+ # vector decay, h = Diag(gk) @ h
229
+ if USE_GK:
230
+ if HEAD_FIRST:
231
+ p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
232
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
233
+ else:
234
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
235
+
236
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
237
+ b_h *= exp(b_gk_last)[:, None]
238
+
239
+ # vector decay, h = h @ Diag(gv)
240
+ if USE_GV:
241
+ if HEAD_FIRST:
242
+ p_gv_last = gv + i_nh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
243
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
244
+ else:
245
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
246
+
247
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
248
+ b_h *= exp(b_gv_last)[None, :]
249
+
250
+ if STORE_FINAL_STATE:
251
+ p_kvt = tl.make_block_ptr(kvt + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
252
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
253
+ b_h += tl.load(p_kvt, boundary_check=(0, 1)).to(tl.float32)
254
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
255
+
256
+
257
+ @triton.heuristics({
258
+ 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
259
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
260
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
261
+ })
262
+ @triton.autotune(
263
+ configs=[
264
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
265
+ for BK in [32, 64, 128]
266
+ for BV in [32, 64, 128]
267
+ for num_warps in [2, 4, 8]
268
+ for num_stages in [2, 3, 4]
269
+ ],
270
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
271
+ )
272
+ @triton.jit(do_not_specialize=['T'])
273
+ def chunk_bwd_kernel_dh_parallel(
274
+ q,
275
+ g,
276
+ gk,
277
+ gv,
278
+ do,
279
+ dh,
280
+ dht,
281
+ dh0,
282
+ offsets,
283
+ indices,
284
+ scale,
285
+ T,
286
+ HQ: tl.constexpr,
287
+ H: tl.constexpr,
288
+ K: tl.constexpr,
289
+ V: tl.constexpr,
290
+ BT: tl.constexpr,
291
+ BK: tl.constexpr,
292
+ BV: tl.constexpr,
293
+ NG: tl.constexpr,
294
+ USE_G: tl.constexpr,
295
+ USE_GK: tl.constexpr,
296
+ USE_GV: tl.constexpr,
297
+ STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
298
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
299
+ USE_OFFSETS: tl.constexpr,
300
+ HEAD_FIRST: tl.constexpr
301
+ ):
302
+ i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
303
+
304
+ NV = tl.cdiv(V, BV)
305
+ i_k, i_v = i_kv // NV, i_kv % NV
306
+ i_b, i_hq, i_bg = i_bh // HQ, i_bh % HQ, i_bh // NG
307
+ i_h = i_hq // NG
308
+ if USE_OFFSETS:
309
+ i_tg = i_t
310
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
311
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
312
+ T = eos - bos
313
+ NT = tl.cdiv(T, BT)
314
+ else:
315
+ bos, eos = i_b * T, i_b * T + T
316
+ NT = tl.cdiv(T, BT)
317
+ i_n, i_tg = i_b, i_b * NT + i_t
318
+ i_nh = i_n * HQ + i_hq
319
+
320
+ if HEAD_FIRST:
321
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
322
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
323
+ p_dh = tl.make_block_ptr(dh + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
324
+ else:
325
+ p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
326
+ p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
327
+ p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
328
+
329
+ if i_t == NT - 1:
330
+ if USE_FINAL_STATE_GRADIENT:
331
+ p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
332
+ b_dh = tl.load(p_dht, boundary_check=(0, 1)).to(tl.float32)
333
+ else:
334
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
335
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
336
+
337
+ # [BK, BT]
338
+ b_q = tl.load(p_q, boundary_check=(0, 1))
339
+ b_q = (b_q * scale).to(b_q.dtype)
340
+ # [BT, BV]
341
+ b_do = tl.load(p_do, boundary_check=(0, 1))
342
+
343
+ if USE_G:
344
+ if HEAD_FIRST:
345
+ p_g = g + i_bg * T + i_t * BT + tl.arange(0, BT)
346
+ p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
347
+ else:
348
+ p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h
349
+ b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
350
+ b_q = (b_q * exp(b_g)[None, :]).to(b_q.dtype)
351
+
352
+ if USE_GK:
353
+ if HEAD_FIRST:
354
+ p_gk = tl.make_block_ptr(gk + i_bg * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
355
+ else:
356
+ p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
357
+ b_gk = tl.load(p_gk, boundary_check=(0, 1))
358
+ b_q = (b_q * exp(b_gk)).to(b_q.dtype)
359
+
360
+ if USE_GV:
361
+ if HEAD_FIRST:
362
+ p_gv = tl.make_block_ptr(gv + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
363
+ else:
364
+ p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
365
+ b_gv = tl.load(p_gv, boundary_check=(0, 1))
366
+ b_do = (b_do * exp(b_gv)).to(b_do.dtype)
367
+
368
+ b_dh = tl.dot(b_q, b_do)
369
+ if i_t > 0:
370
+ if HEAD_FIRST:
371
+ p_dh = tl.make_block_ptr(dh + (i_bh * NT + i_t - 1) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
372
+ else:
373
+ p_dh = tl.make_block_ptr(dh + ((i_tg - 1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
374
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
375
+ elif STORE_INITIAL_STATE_GRADIENT:
376
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
377
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
378
+
379
+
380
+ @triton.heuristics({
381
+ 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
382
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
383
+ })
384
+ @triton.autotune(
385
+ configs=[
386
+ triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages)
387
+ for BK in [32, 64, 128]
388
+ for BV in [32, 64, 128]
389
+ for num_warps in [2, 4, 8, 16]
390
+ for num_stages in [2, 3]
391
+ ],
392
+ key=['BT', 'USE_G', 'USE_GK', 'USE_GV']
393
+ )
394
+ @triton.jit(do_not_specialize=['T'])
395
+ def chunk_bwd_kernel_dh_reduction(
396
+ g,
397
+ gk,
398
+ gv,
399
+ dh,
400
+ doq0,
401
+ dh0,
402
+ offsets,
403
+ chunk_offsets,
404
+ T,
405
+ HQ: tl.constexpr,
406
+ H: tl.constexpr,
407
+ K: tl.constexpr,
408
+ V: tl.constexpr,
409
+ BT: tl.constexpr,
410
+ BK: tl.constexpr,
411
+ BV: tl.constexpr,
412
+ NG: tl.constexpr,
413
+ USE_G: tl.constexpr,
414
+ USE_GK: tl.constexpr,
415
+ USE_GV: tl.constexpr,
416
+ STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
417
+ USE_OFFSETS: tl.constexpr,
418
+ HEAD_FIRST: tl.constexpr
419
+ ):
420
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
421
+ i_bg = i_nh // NG
422
+ i_n, i_hq = i_nh // HQ, i_nh % HQ
423
+ i_h = i_hq // NG
424
+ if USE_OFFSETS:
425
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
426
+ T = eos - bos
427
+ NT = tl.cdiv(T, BT)
428
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
429
+ else:
430
+ bos, eos = i_n * T, i_n * T + T
431
+ NT = tl.cdiv(T, BT)
432
+ boh = i_n * NT
433
+
434
+ # [BK, BV]
435
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
436
+ for i_t in range(NT - 1, -1, -1):
437
+ if HEAD_FIRST:
438
+ p_dh = tl.make_block_ptr(dh + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
439
+ else:
440
+ p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
441
+ b_dh += tl.load(p_dh, boundary_check=(0, 1)).to(tl.float32)
442
+ if i_t < NT - 1:
443
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
444
+
445
+ last_idx = min(i_t * BT + BT, T) - 1
446
+ if USE_G:
447
+ if HEAD_FIRST:
448
+ b_g_last = tl.load(g + i_bg * T + last_idx)
449
+ else:
450
+ b_g_last = tl.load(g + (bos + last_idx) * H + i_h)
451
+ b_dh *= exp(b_g_last)
452
+
453
+ if USE_GK:
454
+ if HEAD_FIRST:
455
+ p_gk_last = gk + (i_bg * T + last_idx) * K + i_k * BK + tl.arange(0, BK)
456
+ p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
457
+ else:
458
+ p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
459
+
460
+ b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
461
+ b_dh *= exp(b_gk_last)[:, None]
462
+
463
+ if USE_GV:
464
+ if HEAD_FIRST:
465
+ p_gv_last = gv + (i_bg * T + last_idx) * V + i_v * BV + tl.arange(0, BV)
466
+ p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
467
+ else:
468
+ p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
469
+
470
+ b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
471
+ b_dh *= exp(b_gv_last)[None, :]
472
+
473
+ if STORE_INITIAL_STATE_GRADIENT:
474
+ p_doq0 = tl.make_block_ptr(doq0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
475
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
476
+ b_dh += tl.load(p_doq0, boundary_check=(0, 1)).to(tl.float32)
477
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
478
+
479
+
480
+ def chunk_fwd_h(
481
+ k: torch.Tensor,
482
+ v: torch.Tensor,
483
+ g: torch.Tensor,
484
+ gk: torch.Tensor,
485
+ gv: torch.Tensor,
486
+ h0: torch.Tensor,
487
+ output_final_state: bool,
488
+ states_in_fp32: bool = False,
489
+ offsets: Optional[torch.Tensor] = None,
490
+ indices: Optional[torch.Tensor] = None,
491
+ head_first: bool = True,
492
+ chunk_size: int = 64
493
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
494
+ if head_first:
495
+ B, H, T, K, V = *k.shape, v.shape[-1]
496
+ else:
497
+ B, T, H, K, V = *k.shape, v.shape[-1]
498
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
499
+ # N: the actual number of sequences in the batch with either equal or variable lengths
500
+ if offsets is None:
501
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
502
+ else:
503
+ if indices is None:
504
+ indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], BT).tolist()])
505
+ indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)
506
+ N, NT = len(offsets) - 1, len(indices)
507
+ chunk_offsets = torch.cat([offsets.new_tensor([0]), triton.cdiv(offsets[1:] - offsets[:-1], BT)]).cumsum(-1)
508
+
509
+ h = k.new_empty(B, H, NT, K, V, dtype=torch.float) if head_first else k.new_empty(B, NT, H, K, V, dtype=torch.float)
510
+ ht = k.new_empty(N, H, K, V, dtype=torch.float) if output_final_state else None
511
+ def grid(meta): return (triton.cdiv(K, meta['BK']) * triton.cdiv(V, meta['BV']), NT, B * H)
512
+ chunk_fwd_kernel_h_parallel[grid](
513
+ k=k,
514
+ v=v,
515
+ h=h,
516
+ g=g,
517
+ gk=gk,
518
+ gv=gv,
519
+ h0=h0,
520
+ ht=ht,
521
+ offsets=offsets,
522
+ indices=indices,
523
+ T=T,
524
+ H=H,
525
+ K=K,
526
+ V=V,
527
+ BT=BT,
528
+ USE_G=g is not None,
529
+ USE_GK=gk is not None,
530
+ USE_GV=gv is not None,
531
+ HEAD_FIRST=head_first
532
+ )
533
+ kvt, ht = ht, (torch.empty_like(ht) if output_final_state else None)
534
+ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H)
535
+ chunk_fwd_kernel_h_reduction[grid](
536
+ h=h,
537
+ g=g,
538
+ gk=gk,
539
+ gv=gv,
540
+ kvt=kvt,
541
+ ht=ht,
542
+ offsets=offsets,
543
+ chunk_offsets=chunk_offsets,
544
+ T=T,
545
+ H=H,
546
+ K=K,
547
+ V=V,
548
+ BT=BT,
549
+ USE_G=g is not None,
550
+ USE_GK=gk is not None,
551
+ USE_GV=gv is not None,
552
+ HEAD_FIRST=head_first
553
+ )
554
+ h = h.to(k.dtype) if not states_in_fp32 else h
555
+ return h, ht
556
+
557
+
558
+ def chunk_bwd_dh(
559
+ q: torch.Tensor,
560
+ k: torch.Tensor,
561
+ v: torch.Tensor,
562
+ g: torch.Tensor,
563
+ gk: torch.Tensor,
564
+ gv: torch.Tensor,
565
+ do: torch.Tensor,
566
+ h0: torch.Tensor,
567
+ dht: torch.Tensor,
568
+ scale: float,
569
+ states_in_fp32: bool = False,
570
+ offsets: Optional[torch.Tensor] = None,
571
+ indices: Optional[torch.Tensor] = None,
572
+ head_first: bool = True,
573
+ chunk_size: int = 64
574
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
575
+ if head_first:
576
+ B, H, T, K, V = *k.shape, v.shape[-1]
577
+ HQ = q.shape[1]
578
+ else:
579
+ B, T, H, K, V = *k.shape, v.shape[-1]
580
+ HQ = q.shape[2]
581
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
582
+ # N: the actual number of sequences in the batch with either equal or variable lengths
583
+ # NG: number of groups in GQA
584
+ if offsets is None:
585
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
586
+ else:
587
+ if indices is None:
588
+ indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], BT).tolist()])
589
+ indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets)
590
+ N, NT = len(offsets) - 1, len(indices)
591
+ chunk_offsets = torch.cat([offsets.new_tensor([0]), triton.cdiv(offsets[1:] - offsets[:-1], BT)]).cumsum(-1)
592
+ NG = HQ // H
593
+
594
+ if head_first:
595
+ dh = k.new_empty(B, HQ, NT, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
596
+ else:
597
+ dh = k.new_empty(B, NT, HQ, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
598
+ dh0 = torch.empty_like(h0, dtype=torch.float) if h0 is not None else None
599
+
600
+ def grid(meta): return (triton.cdiv(K, meta['BK']) * triton.cdiv(V, meta['BV']), NT, B * HQ)
601
+ chunk_bwd_kernel_dh_parallel[grid](
602
+ q=q,
603
+ g=g,
604
+ gk=gk,
605
+ gv=gv,
606
+ do=do,
607
+ dh=dh,
608
+ dht=dht,
609
+ dh0=dh0,
610
+ offsets=offsets,
611
+ indices=indices,
612
+ scale=scale,
613
+ T=T,
614
+ HQ=HQ,
615
+ H=H,
616
+ K=K,
617
+ V=V,
618
+ BT=BT,
619
+ NG=NG,
620
+ USE_G=g is not None,
621
+ USE_GK=gk is not None,
622
+ USE_GV=gv is not None,
623
+ HEAD_FIRST=head_first
624
+ )
625
+
626
+ doq0, dh0 = dh0, (torch.empty_like(dh0) if dh0 is not None else None)
627
+ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * HQ)
628
+ chunk_bwd_kernel_dh_reduction[grid](
629
+ g=g,
630
+ gk=gk,
631
+ gv=gv,
632
+ dh=dh,
633
+ doq0=doq0,
634
+ dh0=dh0,
635
+ offsets=offsets,
636
+ chunk_offsets=chunk_offsets,
637
+ T=T,
638
+ HQ=HQ,
639
+ H=H,
640
+ K=K,
641
+ V=V,
642
+ BT=BT,
643
+ NG=NG,
644
+ USE_G=g is not None,
645
+ USE_GK=gk is not None,
646
+ USE_GV=gv is not None,
647
+ HEAD_FIRST=head_first
648
+ )
649
+ dh = dh.to(q.dtype) if not states_in_fp32 else dh
650
+ return dh, dh0
fla/ops/common/fused_recurrent.py ADDED
@@ -0,0 +1,575 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils import chunk_global_cumsum
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
13
+
14
+
15
+ @triton.heuristics({
16
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
17
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
18
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
19
+ })
20
+ @triton.autotune(
21
+ configs=[
22
+ triton.Config({}, num_warps=num_warps)
23
+ for num_warps in [1, 2, 4]
24
+ ],
25
+ key=["BK", "BV", "USE_GK", "USE_GV", "USE_G"],
26
+ )
27
+ @triton.jit(do_not_specialize=['T'])
28
+ def fused_recurrent_fwd_kernel(
29
+ q,
30
+ k,
31
+ v,
32
+ g,
33
+ gk,
34
+ gv,
35
+ o,
36
+ h0,
37
+ ht,
38
+ offsets,
39
+ scale,
40
+ T,
41
+ B: tl.constexpr,
42
+ H: tl.constexpr,
43
+ K: tl.constexpr,
44
+ V: tl.constexpr,
45
+ BK: tl.constexpr,
46
+ BV: tl.constexpr,
47
+ REVERSE: tl.constexpr,
48
+ USE_G: tl.constexpr,
49
+ USE_GK: tl.constexpr,
50
+ USE_GV: tl.constexpr,
51
+ USE_INITIAL_STATE: tl.constexpr,
52
+ STORE_FINAL_STATE: tl.constexpr,
53
+ USE_OFFSETS: tl.constexpr,
54
+ HEAD_FIRST: tl.constexpr
55
+ ):
56
+ # indices
57
+ i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64)
58
+ i_n, i_h = i_nh // H, i_nh % H
59
+ if USE_OFFSETS:
60
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
61
+ all = T
62
+ T = eos - bos
63
+ else:
64
+ bos, eos = i_n * T, i_n * T + T
65
+ all = B * T
66
+
67
+ if HEAD_FIRST:
68
+ p_q = q + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
69
+ p_k = k + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
70
+ p_v = v + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
71
+ p_o = o + (i_k * B*H + i_nh) * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
72
+ if USE_G:
73
+ p_g = g + i_nh * T + ((T-1) if REVERSE else 0)
74
+ if USE_GK:
75
+ p_gk = gk + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
76
+ if USE_GV:
77
+ p_gv = gv + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
78
+ else:
79
+ p_q = q + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
80
+ p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
81
+ p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
82
+ p_o = o + ((i_k * all + bos) + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
83
+ if USE_G:
84
+ p_g = g + (bos + ((T-1) if REVERSE else 0)) * H + i_h
85
+ if USE_GK:
86
+ p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
87
+ if USE_GV:
88
+ p_gv = gv + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
89
+
90
+ mask_k = (i_k * BK + tl.arange(0, BK)) < K
91
+ mask_v = (i_v * BV + tl.arange(0, BV)) < V
92
+ mask_h = mask_k[None, :] & mask_v[:, None]
93
+ b_h = tl.zeros([BV, BK], dtype=tl.float32)
94
+
95
+ if USE_INITIAL_STATE:
96
+ p_h0 = h0 + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
97
+ b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
98
+
99
+ for _ in range(0, T):
100
+ b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
101
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
102
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
103
+ if USE_GK:
104
+ b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
105
+ b_h = b_h * exp(b_gk[None, :])
106
+ if USE_GV:
107
+ b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)
108
+ b_h = b_h * exp(b_gv[:, None])
109
+ if USE_G:
110
+ b_g = tl.load(p_g).to(tl.float32)
111
+ b_h = b_h * exp(b_g)
112
+ b_h += b_k[None, :] * b_v[:, None]
113
+ b_o = b_h * b_q[None, :]
114
+ b_o = tl.sum(b_o, axis=1)
115
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
116
+ p_q += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
117
+ p_k += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
118
+ p_v += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
119
+ p_o += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
120
+ if USE_GK:
121
+ p_gk += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
122
+ if USE_GV:
123
+ p_gv += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
124
+ if USE_G:
125
+ p_g += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H)
126
+
127
+ if STORE_FINAL_STATE:
128
+ p_ht = ht + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
129
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
130
+
131
+
132
+ @triton.heuristics({
133
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
134
+ 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None,
135
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
136
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
137
+ })
138
+ @triton.autotune(
139
+ configs=[
140
+ triton.Config({}, num_warps=num_warps)
141
+ for num_warps in [1, 2, 4]
142
+ ],
143
+ key=['BK', 'BV', 'USE_GK', 'USE_GV', 'USE_G'],
144
+ )
145
+ @triton.jit(do_not_specialize=['T'])
146
+ def fused_recurrent_bwd_kernel(
147
+ q,
148
+ k,
149
+ v,
150
+ g,
151
+ gk,
152
+ gv,
153
+ h0,
154
+ do,
155
+ dq,
156
+ dk,
157
+ dv,
158
+ dht,
159
+ dh0,
160
+ offsets,
161
+ scale,
162
+ T,
163
+ B: tl.constexpr,
164
+ H: tl.constexpr,
165
+ K: tl.constexpr,
166
+ V: tl.constexpr,
167
+ BK: tl.constexpr,
168
+ BV: tl.constexpr,
169
+ REVERSE: tl.constexpr,
170
+ USE_G: tl.constexpr,
171
+ USE_GK: tl.constexpr,
172
+ USE_GV: tl.constexpr,
173
+ USE_INITIAL_STATE: tl.constexpr,
174
+ STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
175
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
176
+ USE_OFFSETS: tl.constexpr,
177
+ HEAD_FIRST: tl.constexpr
178
+ ):
179
+ i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64)
180
+ i_n, i_h = i_nh // H, i_nh % H
181
+ if USE_OFFSETS:
182
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
183
+ all = T
184
+ T = eos - bos
185
+ else:
186
+ bos, eos = i_n * T, i_n * T + T
187
+ all = B * T
188
+
189
+ if HEAD_FIRST:
190
+ p_k = k + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
191
+ p_v = v + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
192
+ p_do = do + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
193
+ p_dq = dq + (i_v * B*H + i_nh) * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
194
+ if USE_G:
195
+ p_g = g + i_nh * T + ((T-1) if REVERSE else 0)
196
+ if USE_GK:
197
+ p_gk = gk + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK)
198
+ if USE_GV:
199
+ p_gv = gv + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV)
200
+ else:
201
+ p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
202
+ p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
203
+ p_do = do + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
204
+ p_dq = dq + ((i_v * all + bos) + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
205
+ if USE_G:
206
+ p_g = g + (bos + ((T-1) if REVERSE else 0)) * H + i_h
207
+ if USE_GK:
208
+ p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
209
+ if USE_GV:
210
+ p_gv = gv + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
211
+
212
+ mask_k = i_k * BK + tl.arange(0, BK) < K
213
+ mask_v = i_v * BV + tl.arange(0, BV) < V
214
+ mask_h = mask_k[:, None] & mask_v[None, :]
215
+
216
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
217
+ if USE_INITIAL_STATE:
218
+ p_h0 = h0 + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
219
+ b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
220
+
221
+ for _ in range(0, T):
222
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
223
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
224
+ b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
225
+ if USE_G:
226
+ b_g = tl.load(p_g).to(tl.float32)
227
+ b_h = b_h * exp(b_g)
228
+ if USE_GK:
229
+ b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
230
+ b_h = b_h * exp(b_gk[:, None])
231
+ if USE_GV:
232
+ b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)
233
+ b_h = b_h * exp(b_gv[None, :])
234
+ b_h += b_k[:, None] * b_v[None, :]
235
+ b_dq = b_h * b_do[None, :]
236
+ b_dq = tl.sum(b_dq, axis=1) * scale
237
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), mask=mask_k)
238
+
239
+ p_k += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
240
+ p_v += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
241
+ p_do += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
242
+ p_dq += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
243
+ if USE_G:
244
+ p_g += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H)
245
+ if USE_GK:
246
+ p_gk += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
247
+ if USE_GV:
248
+ p_gv += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
249
+
250
+ # sync threads
251
+ tl.debug_barrier()
252
+
253
+ if HEAD_FIRST:
254
+ p_q = q + i_nh * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK)
255
+ p_k = k + i_nh * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK)
256
+ p_v = v + i_nh * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV)
257
+ p_do = do + i_nh * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV)
258
+ p_dk = dk + (i_v * B*H + i_nh) * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK)
259
+ p_dv = dv + (i_k * B*H + i_nh) * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV)
260
+ if USE_G:
261
+ p_g = g + i_nh * T + ((T - 1) if not REVERSE else 0)
262
+ if USE_GK:
263
+ p_gk = gk + i_nh * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK)
264
+ if USE_GV:
265
+ p_gv = gv + i_nh * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV)
266
+ else:
267
+ p_q = q + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
268
+ p_k = k + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
269
+ p_v = v + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
270
+ p_do = do + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
271
+ p_dk = dk + ((i_v * all + bos) + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
272
+ p_dv = dv + ((i_k * all + bos) + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
273
+ if USE_G:
274
+ p_g = g + (bos + ((T - 1) if not REVERSE else 0)) * H + i_h
275
+ if USE_GK:
276
+ p_gk = gk + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
277
+ if USE_GV:
278
+ p_gv = gv + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
279
+
280
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
281
+ if USE_FINAL_STATE_GRADIENT:
282
+ p_dht = dht + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
283
+ b_dh += tl.load(p_dht, mask=mask_h, other=0).to(tl.float32)
284
+
285
+ for _ in range(T):
286
+ b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
287
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
288
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
289
+ b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32)
290
+ b_dh += b_q[:, None] * b_do[None, :]
291
+ b_dk = tl.sum(b_dh * b_v[None, :], axis=1)
292
+ b_dv = tl.sum(b_dh * b_k[:, None], axis=0)
293
+ if USE_G:
294
+ b_g = tl.load(p_g).to(tl.float32)
295
+ b_dh *= exp(b_g)
296
+ if USE_GK:
297
+ b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
298
+ b_dh *= exp(b_gk)[:, None]
299
+ if USE_GV:
300
+ b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)
301
+ b_dh *= exp(b_gv)[None, :]
302
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_k)
303
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_v)
304
+
305
+ p_q += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K
306
+ p_k += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K
307
+ p_v += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V
308
+ p_do += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V
309
+ p_dk += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K
310
+ p_dv += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V
311
+ if USE_G:
312
+ p_g += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H)
313
+ if USE_GK:
314
+ p_gk += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K
315
+ if USE_GV:
316
+ p_gv += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V
317
+
318
+ if STORE_INITIAL_STATE_GRADIENT:
319
+ p_dh0 = dh0 + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
320
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_h)
321
+
322
+
323
+ def fused_recurrent_fwd(
324
+ q: torch.Tensor,
325
+ k: torch.Tensor,
326
+ v: torch.Tensor,
327
+ g: Optional[torch.Tensor] = None,
328
+ gk: Optional[torch.Tensor] = None,
329
+ gv: Optional[torch.Tensor] = None,
330
+ scale: Optional[float] = None,
331
+ initial_state: Optional[torch.Tensor] = None,
332
+ output_final_state: bool = False,
333
+ reverse: bool = False,
334
+ offsets: Optional[torch.LongTensor] = None,
335
+ head_first: bool = True
336
+ ):
337
+ if head_first:
338
+ B, H, T, K, V = *k.shape, v.shape[-1]
339
+ else:
340
+ B, T, H, K, V = *k.shape, v.shape[-1]
341
+ N = B if offsets is None else len(offsets) - 1
342
+ BK, BV = min(K, 64), min(V, 64)
343
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
344
+
345
+ h0 = initial_state
346
+ if output_final_state:
347
+ ht = q.new_empty(N, H, K, V, dtype=torch.float32)
348
+ else:
349
+ ht = None
350
+ o = q.new_empty(NK, *v.shape, dtype=torch.float32)
351
+
352
+ grid = (NV, NK, N * H)
353
+ fused_recurrent_fwd_kernel[grid](
354
+ q,
355
+ k,
356
+ v,
357
+ g,
358
+ gk,
359
+ gv,
360
+ o,
361
+ h0,
362
+ ht,
363
+ offsets,
364
+ scale,
365
+ T=T,
366
+ B=B,
367
+ H=H,
368
+ K=K,
369
+ V=V,
370
+ BK=BK,
371
+ BV=BV,
372
+ USE_G=g is not None,
373
+ USE_GK=gk is not None,
374
+ USE_GV=gv is not None,
375
+ REVERSE=reverse,
376
+ HEAD_FIRST=head_first
377
+ )
378
+ o = o.sum(0)
379
+ return o, ht
380
+
381
+
382
+ def fused_recurrent_bwd(
383
+ q: torch.Tensor,
384
+ k: torch.Tensor,
385
+ v: torch.Tensor,
386
+ g: Optional[torch.Tensor] = None,
387
+ gk: Optional[torch.Tensor] = None,
388
+ gv: Optional[torch.Tensor] = None,
389
+ o: Optional[torch.Tensor] = None,
390
+ do: Optional[torch.Tensor] = None,
391
+ dht: Optional[torch.Tensor] = None,
392
+ scale: Optional[float] = None,
393
+ initial_state: Optional[torch.Tensor] = None,
394
+ reverse: bool = False,
395
+ offsets: Optional[torch.LongTensor] = None,
396
+ head_first: bool = True
397
+ ):
398
+ if head_first:
399
+ B, H, T, K, V = *k.shape, v.shape[-1]
400
+ else:
401
+ B, T, H, K, V = *k.shape, v.shape[-1]
402
+ N = B if offsets is None else len(offsets) - 1
403
+
404
+ BK, BV = min(K, 64), min(V, 64)
405
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
406
+
407
+ dq = q.new_empty(NV, *q.shape, dtype=torch.float32)
408
+ dk = q.new_empty(NV, *k.shape, dtype=torch.float32)
409
+ dv = q.new_empty(NK, *v.shape, dtype=torch.float32)
410
+ h0 = initial_state
411
+ dh0 = torch.empty_like(initial_state) if initial_state is not None else None
412
+
413
+ grid = (NV, NK, N * H)
414
+ fused_recurrent_bwd_kernel[grid](
415
+ q,
416
+ k,
417
+ v,
418
+ g,
419
+ gk,
420
+ gv,
421
+ h0,
422
+ do,
423
+ dq,
424
+ dk,
425
+ dv,
426
+ dht,
427
+ dh0,
428
+ offsets,
429
+ scale,
430
+ B=B,
431
+ T=T,
432
+ H=H,
433
+ K=K,
434
+ V=V,
435
+ BK=BK,
436
+ BV=BV,
437
+ USE_G=g is not None,
438
+ USE_GK=gk is not None,
439
+ USE_GV=gv is not None,
440
+ REVERSE=reverse,
441
+ HEAD_FIRST=head_first
442
+ )
443
+ dq = dq.sum(0)
444
+ dk = dk.sum(0)
445
+ dv = dv.sum(0)
446
+ dg, dgk, dgv = None, None, None
447
+ if g is not None:
448
+ dg = chunk_global_cumsum(
449
+ (dq * q.float() - dk * k.float()).sum(-1),
450
+ reverse=not reverse,
451
+ offsets=offsets,
452
+ head_first=head_first
453
+ )
454
+ if gk is not None:
455
+ dgk = chunk_global_cumsum(
456
+ dq * q.float() - dk * k.float(),
457
+ reverse=not reverse,
458
+ offsets=offsets,
459
+ head_first=head_first
460
+ )
461
+ if gv is not None:
462
+ dgv = chunk_global_cumsum(
463
+ do.float() * o.float() - dv * v.float(),
464
+ reverse=not reverse,
465
+ offsets=offsets,
466
+ head_first=head_first
467
+ )
468
+
469
+ return dq, dk, dv, dg, dgk, dgv, dh0
470
+
471
+
472
+ class FusedRecurrentFunction(torch.autograd.Function):
473
+
474
+ @staticmethod
475
+ @input_guard
476
+ @autocast_custom_fwd
477
+ def forward(
478
+ ctx,
479
+ q: torch.Tensor,
480
+ k: torch.Tensor,
481
+ v: torch.Tensor,
482
+ g: Optional[torch.Tensor] = None,
483
+ gk: Optional[torch.Tensor] = None,
484
+ gv: Optional[torch.Tensor] = None,
485
+ scale: Optional[float] = None,
486
+ initial_state: Optional[torch.Tensor] = None,
487
+ output_final_state: bool = False,
488
+ reverse: bool = False,
489
+ offsets: Optional[torch.LongTensor] = None,
490
+ head_first: bool = True
491
+ ):
492
+ o, ht = fused_recurrent_fwd(
493
+ q=q,
494
+ k=k,
495
+ v=v,
496
+ g=g,
497
+ gk=gk,
498
+ gv=gv,
499
+ scale=scale,
500
+ initial_state=initial_state,
501
+ output_final_state=output_final_state,
502
+ reverse=reverse,
503
+ offsets=offsets,
504
+ head_first=head_first
505
+ )
506
+ ctx.save_for_backward(q, k, v, g, gk, gv, initial_state, o)
507
+ ctx.scale = scale
508
+ ctx.reverse = reverse
509
+ ctx.offsets = offsets
510
+ ctx.head_first = head_first
511
+ return o.to(q.dtype), ht
512
+
513
+ @staticmethod
514
+ @input_guard
515
+ @autocast_custom_bwd
516
+ def backward(ctx, do, dht):
517
+ q, k, v, g, gk, gv, initial_state, o = ctx.saved_tensors
518
+ # not supported yet.
519
+ if dht is not None:
520
+ if not dht.eq(0).all():
521
+ if g is not None:
522
+ assert g.requires_grad is False, "Cannot load final state gradient and use gates at the same time"
523
+ if gk is not None:
524
+ assert gk.requires_grad is False, "Cannot load final state gradient and use gates at the same time"
525
+ if gv is not None:
526
+ assert gv.requires_grad is False, "Cannot load final state gradient and use gates at the same time"
527
+ dq, dk, dv, dg, dgk, dgv, dh0 = fused_recurrent_bwd(
528
+ q=q,
529
+ k=k,
530
+ v=v,
531
+ g=g,
532
+ gk=gk,
533
+ gv=gv,
534
+ o=o,
535
+ do=do,
536
+ dht=dht,
537
+ scale=ctx.scale,
538
+ initial_state=initial_state,
539
+ reverse=ctx.reverse,
540
+ offsets=ctx.offsets,
541
+ head_first=ctx.head_first
542
+ )
543
+ return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dg, dgk, dgv, None, dh0, None, None, None, None
544
+
545
+
546
+ def fused_recurrent(
547
+ q: torch.Tensor,
548
+ k: torch.Tensor,
549
+ v: torch.Tensor,
550
+ g: Optional[torch.Tensor] = None,
551
+ gk: Optional[torch.Tensor] = None,
552
+ gv: Optional[torch.Tensor] = None,
553
+ scale: Optional[float] = None,
554
+ initial_state: Optional[torch.Tensor] = None,
555
+ output_final_state: bool = False,
556
+ reverse: bool = False,
557
+ cu_seqlens: Optional[torch.LongTensor] = None,
558
+ head_first: bool = True
559
+ ):
560
+ if scale is None:
561
+ scale = k.shape[-1] ** -0.5
562
+ return FusedRecurrentFunction.apply(
563
+ q,
564
+ k,
565
+ v,
566
+ g,
567
+ gk,
568
+ gv,
569
+ scale,
570
+ initial_state,
571
+ output_final_state,
572
+ reverse,
573
+ cu_seqlens,
574
+ head_first
575
+ )
fla/ops/common/utils.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ import torch
5
+ import triton
6
+ import triton.language as tl
7
+
8
+ from fla.utils import tensor_cache
9
+
10
+
11
+ @triton.autotune(
12
+ configs=[
13
+ triton.Config({}, num_warps=num_warps)
14
+ for num_warps in [4, 8, 16, 32]
15
+ ],
16
+ key=['B'],
17
+ )
18
+ @triton.jit
19
+ def prepare_position_ids_kernel(
20
+ y,
21
+ offsets,
22
+ B: tl.constexpr
23
+ ):
24
+ i_n = tl.program_id(0)
25
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
26
+ T = eos - bos
27
+
28
+ o = tl.arange(0, B)
29
+ for i in range(0, tl.cdiv(T, B) * B, B):
30
+ o_i = o + i
31
+ tl.store(y + bos + o_i, o_i, o_i < T)
32
+
33
+
34
+ @tensor_cache
35
+ def prepare_lens(offsets: torch.LongTensor) -> torch.LongTensor:
36
+ return offsets[1:] - offsets[:-1]
37
+
38
+
39
+ @tensor_cache
40
+ def prepare_position_ids(offsets: torch.LongTensor) -> torch.LongTensor:
41
+ return torch.cat([torch.arange(n, dtype=offsets.dtype, device=offsets.device) for n in prepare_lens(offsets).unbind()])
42
+
43
+
44
+ @tensor_cache
45
+ def prepare_sequence_ids(position_ids: torch.LongTensor) -> torch.LongTensor:
46
+ return position_ids.eq(0).cumsum(0) - 1
47
+
48
+
49
+ @tensor_cache
50
+ def prepare_token_indices(offsets: torch.LongTensor) -> torch.LongTensor:
51
+ position_ids = prepare_position_ids(offsets)
52
+ return torch.stack([prepare_sequence_ids(position_ids), position_ids], 1).to(offsets)
53
+
54
+
55
+ @tensor_cache
56
+ def prepare_chunk_indices(
57
+ offsets: torch.LongTensor,
58
+ chunk_size: int
59
+ ) -> torch.LongTensor:
60
+ indices = torch.cat([torch.arange(n) for n in triton.cdiv(prepare_lens(offsets), chunk_size).tolist()])
61
+ return torch.stack([prepare_sequence_ids(indices), indices], 1).to(offsets)
62
+
63
+
64
+ @tensor_cache
65
+ def prepare_chunk_offsets(
66
+ offsets: torch.LongTensor,
67
+ chunk_size: int
68
+ ) -> torch.LongTensor:
69
+ return torch.cat([offsets.new_tensor([0]), triton.cdiv(prepare_lens(offsets), chunk_size)]).cumsum(-1)
fla/ops/delta_rule/README.md ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Chunkwise-form Parallelism of DeltaNet
2
+
3
+ This section expands on the formulation presented in Appendix B of the DeltaNet paper.[^1]
4
+
5
+ To reduce notational clutter, we focus on the first chunk, denoting $\mathbf{S}^r=\mathbf{S}_{[1]}^r$. By partially expanding the recurrence, we have:
6
+ ```math
7
+ \begin{equation}
8
+ \begin{aligned}
9
+ \mathbf{S}^r &= \underbrace{\left(\prod_{i=1}^r \mathbf{I} - \beta^i \boldsymbol{k}^i \boldsymbol{k}^{i\top} \right)}_{:= \mathbf{P}^r} \cdot\mathbf{S}^{0} + \overbrace{\sum_{i=1}^{r} \underbrace{\left(\prod_{j=i+1}^r \mathbf{I} - \beta^j \boldsymbol{k}^j \boldsymbol{k}^{j\top} \right)}_{:= \mathbf{P}_{i+1}^r}\beta^i \boldsymbol{k}^i\boldsymbol{v}^{i\top}}^{:=\mathbf{H}^r} \\
10
+ &=\mathbf{P}^r \cdot \mathbf{S}^{0} + \mathbf{H}^r
11
+ \end{aligned}
12
+ \end{equation}
13
+ ```
14
+
15
+ where $\mathbf{P}_i^r$ involves cumulative products of generalized Householder matrices.
16
+ We abbreviate $\mathbf{P}_1^r$ as $\mathbf{P}^r$.
17
+ This can be optimized using the classical WY representation:
18
+ ```math
19
+ \begin{equation}
20
+ \mathbf{P}^{r} = \mathbf{I} - \sum_{i=1}^{r}\boldsymbol{k}^i\boldsymbol{w}^{i\top} \in \mathbb{R}^{d_k \times d_k};\qquad
21
+ \boldsymbol{w}^r = \beta^r \left(\boldsymbol{k}^r - \sum_{i=1}^{r-1} \left(\boldsymbol{k}^{r\top}\boldsymbol{k}^i \right)\boldsymbol{w}^i \right) \in \mathbb{R}^{d_k}
22
+ \end{equation}
23
+ ```
24
+
25
+ We prove this by induction:
26
+ ```math
27
+ \begin{align*}
28
+ \mathbf{P}^{r} &= \prod_{i=1}^r \mathbf{I} - \beta^i \boldsymbol{k}^i \boldsymbol{k}^{i\top} \\
29
+ &= \left(\mathbf{I} - \beta^r \boldsymbol{k}^r \boldsymbol{k}^{r\top}\right)\mathbf{P}^{r-1} \\
30
+ &= \left(\mathbf{I} - \beta^r \boldsymbol{k}^r \boldsymbol{k}^{r\top}\right)\left(\mathbf{I} - \sum_{i=1}^{r-1}\boldsymbol{k}^i\boldsymbol{w}^{i\top}\right) \\
31
+ &= \mathbf{I} - \sum_{i=1}^{r-1}\boldsymbol{k}^i\boldsymbol{w}^{i\top} - \beta^r \boldsymbol{k}^r \boldsymbol{k}^{r\top} + \beta^r\boldsymbol{k}^r \boldsymbol{k}^{r\top} \left(\sum_{i=1}^{r-1}\boldsymbol{k}^i\boldsymbol{w}^{i\top}\right) \\
32
+ &= \mathbf{I} - \sum_{i=1}^{r-1}\boldsymbol{k}^i\boldsymbol{w}^{i\top} - \beta^r \boldsymbol{k}^r \left(\boldsymbol{k}^{r} - \left(\sum_{i=1}^{r-1}\left(\boldsymbol{k}^{r\top} \boldsymbol{k}^i\right)\boldsymbol{w}^{i}\right) \right)^\top \\
33
+ &= \mathbf{I} - \sum_{i=1}^{r}\boldsymbol{k}^i\boldsymbol{w}^{i\top}
34
+ \end{align*}
35
+ ```
36
+
37
+ Similarly, $\mathbf{H}^r$ can be represented as:
38
+ ```math
39
+ \begin{equation}
40
+ \mathbf{H}^{r} = \sum_{i=1}^{r} \boldsymbol{k}^i \boldsymbol{u}^{i\top} \in \mathbb{R}^{d_k \times d_v};\qquad \boldsymbol{u}^r = \beta^r \left(\boldsymbol{v}^r - \sum_{i=1}^{r-1} \left(\boldsymbol{k}^{r\top}\boldsymbol{k}^i\right) \boldsymbol{u}^i \right)\in \mathbb{R}^{d_v}
41
+ \end{equation}
42
+ ```
43
+
44
+ This can also be proven by induction:
45
+ ```math
46
+ \begin{align*}
47
+ \mathbf{H}^{r} &= \sum_{i=1}^{r} \mathbf{P}_{i+1}^r \beta^i \boldsymbol{k}^i \boldsymbol{v}^{i\top}\\
48
+ &= \left(\mathbf{I} - \beta^r \boldsymbol{k}^r \boldsymbol{k}^{r\top}\right) \mathbf{H}^{r-1} + \beta^r \boldsymbol{k}^r \boldsymbol{v}^{r\top}\\
49
+ &= \sum_{i=1}^{r-1}\boldsymbol{k}^i \boldsymbol{u}^{i\top} - \beta^r \boldsymbol{k}^r \boldsymbol{k}^{r\top} \sum_{i=1}^{r-1}\boldsymbol{k}^i \boldsymbol{u}^{i\top} +\beta^r \boldsymbol{k}^r \boldsymbol{v}^{r\top}\\
50
+ &= \sum_{i=1}^{r-1}\boldsymbol{k}^i \boldsymbol{u}^{i\top} + \boldsymbol{k}^r \left(\beta^r \boldsymbol{v}^{r\top}-\beta^r \boldsymbol{k}^{r\top} \sum_{i=1}^{r-1}\boldsymbol{k}^i \boldsymbol{u}^{i\top}\right) \\
51
+ &= \sum_{i=1}^{r-1}\boldsymbol{k}^i \boldsymbol{u}^{i\top} + \boldsymbol{k}^r \beta^r\left(\boldsymbol{v}^{r}-\sum_{i=1}^{r-1}\left(\boldsymbol{k}^{r\top}\boldsymbol{k}^{i}\right)\boldsymbol{u}^{i} \right)^\top \\
52
+ &=\sum_{i=1}^{r} \boldsymbol{k}^i \boldsymbol{u}^{i\top}
53
+ \end{align*}
54
+ ```
55
+
56
+ In matrix form, $\mathbf{P}$ and $\mathbf{H}$ can be written as:
57
+ ```math
58
+ \begin{equation}
59
+ \mathbf{P}=\mathbf{I}-\mathbf{K}^\top\mathbf{W} \in \mathbb{R}^{d_k \times d_k}, \qquad\mathbf{H}=\mathbf{K}^\top\mathbf{U} \in \mathbb{R}^{d_k\times d_v}
60
+ \end{equation}
61
+ ```
62
+
63
+ Now we can derive the matrix form of $\mathbf{W}$ and $\mathbf{U}$:
64
+ ```math
65
+ \begin{align*}
66
+ \mathbf{W} &= \mathrm{diag}(\beta) \mathbf{K} - \mathrm{tril}(\mathrm{diag}(\beta) \mathbf{K}\mathbf{K}^\top, -1)\mathbf{W}\\
67
+ \left(\mathbf{I} + \mathrm{tril}(\mathrm{diag}(\beta) \mathbf{K}\mathbf{K}^\top, -1)\right) \mathbf{W} &= \mathrm{diag}(\beta) \mathbf{K}
68
+ \end{align*}
69
+ ```
70
+ A similar process holds for $\mathbf{U}$. We can further write $\mathbf{W}$ and $\mathbf{U}$ in matrix form:
71
+ ```math
72
+ \begin{align*}
73
+ \mathbf{T} &= \left(\mathbf{I} + \mathrm{tril}\left(\mathrm{diag}(\beta)\mathbf{K} \mathbf{K}^\top,-1\right)\right)^{-1}\mathrm{diag}\left(\beta\right)\in \mathbb{R}^{C \times C}\\
74
+ \mathbf{W} &= \mathbf{T} \mathbf{K}\in \mathbb{R}^{C \times d_k}\\
75
+ \mathbf{U} &= \mathbf{T}\mathbf{V}\in \mathbb{R}^{C \times d_v}
76
+ \end{align*}
77
+ ```
78
+
79
+ Substituting these back into the original equations yields a hardware-efficient chunkwise algorithm for DeltaNet that leverages matrix multiplications, enabling tensor core based GPU optimization:
80
+ ```math
81
+ \begin{equation}
82
+ \begin{aligned}
83
+ \mathbf{S} &= \mathbf{P}\cdot\mathbf{S}^0 + \mathbf{H} \\
84
+ &= \mathbf{S}^0 + \mathbf{K}^\top (\mathbf{U} -\mathbf{W} \mathbf{S}^0) \in \mathbb{R}^{d_k \times d_v}\\
85
+ \mathbf{O} &= \mathbf{Q} \mathbf{S}^0 + (\mathbf{Q} \mathbf{K}^{\top} \odot \mathbf{M}) \left(\mathbf{U} - \mathbf{W} \mathbf{S}^0\right) \in \mathbb{R}^{C \times d_v}
86
+ \end{aligned}
87
+ \end{equation}
88
+ ```
89
+
90
+ [^1]: https://arxiv.org/abs/2406.06484
fla/ops/delta_rule/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .chunk import chunk_delta_rule
4
+ from .fused_chunk import fused_chunk_delta_rule
5
+ from .fused_recurrent import fused_recurrent_delta_rule
6
+
7
+ __all__ = [
8
+ 'fused_chunk_delta_rule',
9
+ 'fused_recurrent_delta_rule',
10
+ 'chunk_delta_rule'
11
+ ]
fla/ops/delta_rule/parallel.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+ from einops import rearrange
10
+
11
+ from fla.ops.delta_rule.wy_fast import fwd_prepare_T
12
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
13
+
14
+
15
+ @triton.autotune(
16
+ configs=[
17
+ triton.Config({}, num_warps=num_warps)
18
+ for num_warps in [1, 2, 4]
19
+ ],
20
+ key=['BT', 'K', 'V'],
21
+ )
22
+ @triton.jit(do_not_specialize=['T'])
23
+ def chunk_transform_qk_fwd_kernel(
24
+ q,
25
+ k,
26
+ v,
27
+ beta,
28
+ o,
29
+ A,
30
+ q_new,
31
+ k_new,
32
+ A_local,
33
+ scale,
34
+ T,
35
+ K: tl.constexpr,
36
+ V: tl.constexpr,
37
+ BK: tl.constexpr,
38
+ BV: tl.constexpr,
39
+ BT: tl.constexpr,
40
+ OUTPUT_ATTENTIONS: tl.constexpr
41
+ ):
42
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
43
+
44
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
45
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
46
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, 0), (BT, BV), (1, 0))
47
+ b_q = (tl.load(p_q, boundary_check=(0, 1)) * scale).to(p_q.dtype.element_ty)
48
+ b_k = tl.load(p_k, boundary_check=(0, 1))
49
+ b_v = tl.load(p_v, boundary_check=(0, 1))
50
+
51
+ p_T = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
52
+ b_T = tl.load(p_T, boundary_check=(0, 1))
53
+
54
+ o_i = tl.arange(0, BT)
55
+ m_t = o_i[:, None] >= o_i[None, :]
56
+ b_qk = tl.where(m_t, tl.dot(b_q, tl.trans(b_k), allow_tf32=False), 0).to(b_q.dtype)
57
+ m_t = o_i[:, None] > o_i[None, :]
58
+ b_kk = tl.where(m_t, tl.dot(b_k, tl.trans(b_k), allow_tf32=False), 0).to(b_k.dtype)
59
+
60
+ p_beta = tl.make_block_ptr(beta + i_bh * T, (T, ), (1, ), (i_t * BT, ), (BT, ), (0, ))
61
+ b_beta = tl.load(p_beta, boundary_check=(0, ))
62
+ b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
63
+
64
+ b_qkT = tl.dot(b_qk, b_T, allow_tf32=False).to(b_k.dtype)
65
+
66
+ if OUTPUT_ATTENTIONS:
67
+ p_a = tl.make_block_ptr(A_local + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
68
+ tl.store(p_a, b_qkT.to(p_a.dtype.element_ty), boundary_check=(0, 1))
69
+
70
+ b_kkT = tl.dot(b_kk, b_T, allow_tf32=False).to(b_k.dtype)
71
+ p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, 0), (BT, BV), (1, 0))
72
+ tl.store(p_o, tl.dot(b_qkT, b_v).to(p_o.dtype.element_ty), boundary_check=(0, 1))
73
+
74
+ p_q_new = tl.make_block_ptr(q_new + i_bh * T*K, (T, K), (K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
75
+ tl.store(p_q_new, (b_q - tl.dot(b_qkT, b_k_beta, allow_tf32=False)).to(p_q_new.dtype.element_ty), boundary_check=(0, 1))
76
+
77
+ p_k_new = tl.make_block_ptr(k_new + i_bh * T*K, (T, K), (K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
78
+ b_k_new = b_k - tl.dot(tl.trans(b_kkT), b_k_beta, allow_tf32=False)
79
+ tl.store(p_k_new, b_k_new.to(p_k_new.dtype.element_ty), boundary_check=(0, 1))
80
+
81
+
82
+ def chunk_transform_qk_fwd(
83
+ q: torch.Tensor,
84
+ k: torch.Tensor,
85
+ v: torch.Tensor,
86
+ beta: torch.Tensor,
87
+ A: torch.Tensor,
88
+ scale: float,
89
+ chunk_size: int,
90
+ output_attentions: bool
91
+ ):
92
+ B, H, T, K = k.shape
93
+ BT = chunk_size
94
+ q_new = torch.empty_like(q)
95
+ k_new = torch.empty_like(k)
96
+ o = torch.empty_like(v)
97
+ grid = (triton.cdiv(T, BT), B*H)
98
+ V = v.shape[-1]
99
+ A_local = torch.empty_like(A) if output_attentions else None
100
+ chunk_transform_qk_fwd_kernel[grid](
101
+ q,
102
+ k,
103
+ v,
104
+ beta,
105
+ o,
106
+ A,
107
+ q_new,
108
+ k_new,
109
+ A_local,
110
+ scale=scale,
111
+ T=T,
112
+ K=K,
113
+ V=V,
114
+ BT=BT,
115
+ BK=triton.next_power_of_2(K),
116
+ BV=triton.next_power_of_2(V),
117
+ OUTPUT_ATTENTIONS=output_attentions
118
+ )
119
+ return q_new, k_new, o, A_local
120
+
121
+
122
+ @triton.autotune(
123
+ configs=[
124
+ triton.Config({}, num_warps=1),
125
+ triton.Config({}, num_warps=2),
126
+ ],
127
+ key=['BT'],
128
+ )
129
+ @triton.jit(do_not_specialize=['T'])
130
+ def save_intra_chunk_attn(
131
+ A,
132
+ A_local,
133
+ T,
134
+ BT: tl.constexpr,
135
+ ):
136
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
137
+ p_A = tl.make_block_ptr(A + i_bh * T * T, (T, T), (T, 1), (i_t * BT, i_t * BT), (BT, BT), (1, 0))
138
+ p_A_local = tl.make_block_ptr(A_local + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
139
+ b_A_local = tl.load(p_A_local, boundary_check=(0, 1))
140
+ tl.store(p_A, b_A_local.to(p_A.dtype.element_ty), boundary_check=(0, 1))
141
+
142
+
143
+ @triton.heuristics({
144
+ 'OUTPUT_ATTENTIONS': lambda args: args['attn'] is not None
145
+ })
146
+ @triton.jit(do_not_specialize=['T'])
147
+ def parallel_delta_rule_fwd_kernel(
148
+ q,
149
+ k,
150
+ k2, # original k
151
+ v,
152
+ beta,
153
+ o,
154
+ o_new,
155
+ attn,
156
+ T,
157
+ K: tl.constexpr,
158
+ V: tl.constexpr,
159
+ BT: tl.constexpr,
160
+ BS: tl.constexpr,
161
+ BK: tl.constexpr,
162
+ BV: tl.constexpr,
163
+ OUTPUT_ATTENTIONS: tl.constexpr
164
+ ):
165
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
166
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, 0), (BT, BK), (1, 0))
167
+
168
+ # the Q block is kept in the shared memory throughout the whole kernel
169
+ # [BT, BK]
170
+ b_q = tl.zeros([BT, BK], dtype=tl.float32)
171
+ b_q += tl.load(p_q, boundary_check=(0, 1))
172
+
173
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
174
+ p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, 0), (BT, BV), (1, 0))
175
+ b_o += tl.load(p_o, boundary_check=(0, 1))
176
+
177
+ # As opposed to Flashattention, this kernel requires scanning the KV blocks from right to left
178
+ # Q block and K block have overlap.
179
+ # masks required
180
+ for offset in range((i_t + 1) * BT - 2 * BS, i_t * BT - BS, -BS):
181
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (0, offset), (BK, BS), (0, 1))
182
+ p_k2 = tl.make_block_ptr(k2 + i_bh * T*K, (T, K), (K, 1), (offset, 0), (BS, BK), (1, 0))
183
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (offset, 0), (BS, BV), (1, 0))
184
+ p_beta = tl.make_block_ptr(beta + i_bh * T, (T, ), (1, ), (offset, ), (BS, ), (0,))
185
+ # [BK, BS]
186
+ b_k = tl.load(p_k, boundary_check=(0, 1))
187
+ # [BS, BV]
188
+ b_v = tl.load(p_v, boundary_check=(0, 1))
189
+ # [BS]
190
+ b_beta = tl.load(p_beta, boundary_check=(0,))
191
+ # [BT, BS]
192
+ m_s = tl.arange(0, BT) >= (offset - i_t*BT + BS)
193
+ b_s = tl.dot(b_q.to(b_k.dtype), b_k, allow_tf32=False)
194
+ b_s = tl.where(m_s[:, None], b_s, 0)
195
+
196
+ b_o += tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)
197
+ b_k2 = (tl.load(p_k2, boundary_check=(0, 1)) * b_beta[:, None]).to(b_v.dtype)
198
+ b_q -= tl.dot(b_s.to(b_v.dtype), b_k2, allow_tf32=False)
199
+
200
+ if OUTPUT_ATTENTIONS:
201
+ p_a = tl.make_block_ptr(attn + i_bh * T * T, (T, T), (T, 1), (i_t * BT, offset), (BT, BS), (1, 0))
202
+ tl.store(p_a, b_s.to(p_a.dtype.element_ty), boundary_check=(0, 1))
203
+
204
+ # Q block and K block have no overlap
205
+ # no need for mask, thereby saving flops
206
+ for offset in range(i_t * BT - BS, -BS, -BS):
207
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (0, offset), (BK, BS), (0, 1))
208
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (offset, 0), (BS, BV), (1, 0))
209
+ p_beta = tl.make_block_ptr(beta + i_bh * T, (T, ), (1, ), (offset, ), (BS, ), (0,))
210
+ p_k2 = tl.make_block_ptr(k2 + i_bh * T*K, (T, K), (K, 1), (offset, 0), (BS, BK), (1, 0))
211
+
212
+ # [BK, BS]
213
+ b_k = tl.load(p_k, boundary_check=(0, 1))
214
+ # [BS, BV]
215
+ b_v = tl.load(p_v, boundary_check=(0, 1))
216
+ # [BS]
217
+ b_beta = tl.load(p_beta, boundary_check=(0,))
218
+ # [BT, BS]
219
+ b_s = (tl.dot(b_q.to(b_k.dtype), b_k, allow_tf32=False))
220
+ # [BT, BV]
221
+ b_o += tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)
222
+ b_k2 = (tl.load(p_k2, boundary_check=(0, 1)) * b_beta[:, None]).to(b_v.dtype)
223
+ b_q -= tl.dot(b_s.to(b_v.dtype), b_k2, allow_tf32=False).to(b_q.dtype)
224
+
225
+ if OUTPUT_ATTENTIONS:
226
+ p_a = tl.make_block_ptr(attn + i_bh * T * T, (T, T), (T, 1), (i_t * BT, offset), (BT, BS), (1, 0))
227
+ tl.store(p_a, b_s.to(p_a.dtype.element_ty), boundary_check=(0, 1))
228
+
229
+ p_o_new = tl.make_block_ptr(o_new + i_bh * T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0))
230
+ tl.store(p_o_new, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
231
+
232
+
233
+ class ParallelDeltaRuleFunction(torch.autograd.Function):
234
+
235
+ @staticmethod
236
+ @input_guard
237
+ @autocast_custom_fwd
238
+ def forward(ctx, q, k, v, beta, scale, output_attentions):
239
+ B, H, T, K, V = *k.shape, v.shape[-1]
240
+ assert q.shape[-1] <= 128, 'The maximum supported sequence length is 128.'
241
+ BT, BS = 128, 32
242
+ BK = triton.next_power_of_2(k.shape[-1])
243
+ BV = triton.next_power_of_2(v.shape[-1])
244
+ assert BT % BS == 0
245
+
246
+ A = fwd_prepare_T(k, beta, BS)
247
+ attn = q.new_zeros(B, H, T, T) if output_attentions else None
248
+ q_new, k_new, o, A_local = chunk_transform_qk_fwd(
249
+ q,
250
+ k,
251
+ v,
252
+ beta,
253
+ A,
254
+ scale,
255
+ BS,
256
+ output_attentions
257
+ )
258
+
259
+ num_stages = 3 if K <= 64 else 2
260
+ num_warps = 4
261
+ grid = (triton.cdiv(T, BT), B * H)
262
+ o_new = torch.empty_like(o)
263
+
264
+ parallel_delta_rule_fwd_kernel[grid](
265
+ q=q_new,
266
+ k=k_new,
267
+ k2=k,
268
+ v=v,
269
+ beta=beta,
270
+ o=o,
271
+ o_new=o_new,
272
+ attn=attn,
273
+ T=T,
274
+ K=K,
275
+ V=V,
276
+ BT=BT,
277
+ BS=BS,
278
+ BK=BK,
279
+ BV=BV,
280
+ num_stages=num_stages,
281
+ num_warps=num_warps
282
+ )
283
+
284
+ if output_attentions:
285
+ grid = (triton.cdiv(T, BS), B * H)
286
+ save_intra_chunk_attn[grid](
287
+ A=attn,
288
+ A_local=A_local,
289
+ T=T,
290
+ BT=BS
291
+ )
292
+ return o_new.to(q.dtype), attn
293
+
294
+ @staticmethod
295
+ @input_guard
296
+ @autocast_custom_bwd
297
+ def backward(ctx, do, d_attn=None):
298
+ raise NotImplementedError('Backward pass is not implemented. Stay tuned!')
299
+
300
+
301
+ def parallel_delta_rule(
302
+ q: torch.Tensor,
303
+ k: torch.Tensor,
304
+ v: torch.Tensor,
305
+ beta: torch.Tensor,
306
+ scale: float = None,
307
+ output_attentions: bool = False,
308
+ head_first: bool = True
309
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
310
+ r"""
311
+ Args:
312
+ q (torch.Tensor):
313
+ queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
314
+ k (torch.Tensor):
315
+ keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
316
+ v (torch.Tensor):
317
+ values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
318
+ beta (torch.Tensor):
319
+ betas of shape `[B, H, T]` if `head_first=True` else `[B, T, H]`.
320
+ scale (Optional[int]):
321
+ Scale factor for attention scores.
322
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
323
+ output_attentions (bool):
324
+ Whether to output the materialized attention scores of shape [B, H, T, T]. Default: `False`.
325
+ head_first (Optional[bool]):
326
+ Whether the inputs are in the head-first format.
327
+ Default: `True`.
328
+
329
+ Returns:
330
+ o (torch.Tensor):
331
+ Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
332
+ attn (torch.Tensor):
333
+ Attention scores of shape `[B, H, T, T]` if `output_attentions=True` else `None`.
334
+ """
335
+ if not head_first:
336
+ q, k, v, beta = map(lambda x: x.transpose(1, 2), (q, k, v, beta))
337
+ o, attn = ParallelDeltaRuleFunction.apply(q, k, v, beta, scale, output_attentions)
338
+ if not head_first:
339
+ o = o.transpose(1, 2)
340
+ return o, attn
341
+
342
+
343
+ def naive_delta_rule_parallel(q, k, v, beta, BM=128, BN=32):
344
+ b, h, l, d_k = q.shape
345
+ q = q * (d_k ** -0.5)
346
+ v = v * beta[..., None]
347
+ k_beta = k * beta[..., None]
348
+ # compute (I - tri(diag(beta) KK^T))^{-1}
349
+ q, k, v, k_beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=BN), [q, k, v, k_beta])
350
+ mask = torch.triu(torch.ones(BN, BN, dtype=torch.bool, device=q.device), diagonal=0)
351
+ T = -(k_beta @ k.transpose(-1, -2)).masked_fill(mask, 0)
352
+ for i in range(1, BN):
353
+ T[..., i, :i] = T[..., i, :i].clone() + (T[..., i, :, None].clone() * T[..., :, :i].clone()).sum(-2)
354
+ T = T + torch.eye(BN, dtype=q.dtype, device=q.device)
355
+
356
+ mask2 = torch.triu(torch.ones(BN, BN, dtype=torch.bool, device=q.device), diagonal=1)
357
+ A_local = (q @ k.transpose(-1, -2)).masked_fill(mask2, 0) @ T
358
+ o_intra = A_local @ v
359
+
360
+ # apply cumprod transition matrices on k to the last position within the chunk
361
+ k = k - ((k @ k.transpose(-1, -2)).masked_fill(mask, 0) @ T).transpose(-1, -2) @ k_beta
362
+ # apply cumprod transition matrices on q to the first position within the chunk
363
+ q = q - A_local @ k_beta
364
+ o_intra = A_local @ v
365
+
366
+ A = torch.zeros(b, h, l, l, device=q.device)
367
+
368
+ q, k, v, k_beta, o_intra = map(lambda x: rearrange(x, 'b h n c d -> b h (n c) d'), [q, k, v, k_beta, o_intra])
369
+ o = torch.empty_like(v)
370
+ for i in range(0, l, BM):
371
+ q_i = q[:, :, i:i+BM]
372
+ o_i = o_intra[:, :, i:i+BM]
373
+ # intra block
374
+ for j in range(i + BM - 2 * BN, i-BN, -BN):
375
+ k_j = k[:, :, j:j+BN]
376
+ A_ij = q_i @ k_j.transpose(-1, -2)
377
+ mask = torch.arange(i, i+BM) >= (j + BN)
378
+ A_ij = A_ij.masked_fill_(~mask[:, None].to(A_ij.device), 0)
379
+ A[:, :, i:i+BM, j:j+BN] = A_ij
380
+ q_i = q_i - A_ij @ k_beta[:, :, j:j+BN]
381
+ o_i += A_ij @ v[:, :, j:j+BN]
382
+ # inter block
383
+ for j in range(i - BN, -BN, -BN):
384
+ k_j = k[:, :, j:j+BN]
385
+ A_ij = q_i @ k_j.transpose(-1, -2)
386
+ A[:, :, i:i+BM, j:j+BN] = A_ij
387
+ q_i = q_i - A_ij @ k_beta[:, :, j:j+BN]
388
+ o_i += A_ij @ v[:, :, j:j+BN]
389
+ o[:, :, i:i+BM] = o_i
390
+
391
+ for i in range(0, l//BN):
392
+ A[:, :, i*BN:i*BN+BN, i*BN:i*BN+BN] = A_local[:, :, i]
393
+
394
+ return o, A
fla/ops/gated_delta_rule/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .chunk import chunk_gated_delta_rule
2
+ from .fused_recurrent import fused_recurrent_gated_delta_rule
3
+
4
+ __all__ = [
5
+ "chunk_gated_delta_rule",
6
+ "fused_recurrent_gated_delta_rule"
7
+ ]
fla/ops/gated_delta_rule/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (14.4 kB). View file
 
fla/ops/gated_delta_rule/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (15.1 kB). View file
 
fla/ops/generalized_delta_rule/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from .dplr import chunk_dplr_delta_rule, fused_recurrent_dplr_delta_rule
2
+ from .iplr import chunk_iplr_delta_rule, fused_recurrent_iplr_delta_rule
3
+
4
+ __all__ = [
5
+ 'chunk_dplr_delta_rule',
6
+ 'fused_recurrent_dplr_delta_rule',
7
+ 'chunk_iplr_delta_rule',
8
+ 'fused_recurrent_iplr_delta_rule'
9
+ ]
fla/ops/generalized_delta_rule/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (389 Bytes). View file
 
fla/ops/generalized_delta_rule/dplr/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .chunk import chunk_dplr_delta_rule
2
+ from .fused_recurrent import fused_recurrent_dplr_delta_rule
3
+
4
+ __all__ = [
5
+ 'chunk_dplr_delta_rule',
6
+ 'fused_recurrent_dplr_delta_rule'
7
+ ]
fla/ops/generalized_delta_rule/dplr/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (328 Bytes). View file
 
fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_bwd.cpython-312.pyc ADDED
Binary file (30.6 kB). View file
 
fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_bwd.cpython-312.pyc ADDED
Binary file (12.2 kB). View file
 
fla/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_fwd.cpython-312.pyc ADDED
Binary file (21.3 kB). View file
 
fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import exp, gather
11
+ from fla.utils import check_shared_mem, is_gather_supported, use_cuda_graph
12
+
13
+
14
+ @triton.heuristics({
15
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
16
+ })
17
+ @triton.autotune(
18
+ configs=[
19
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
20
+ for num_warps in [2, 4, 8, 16, 32]
21
+ for num_stages in [2, 3, 4]
22
+ ],
23
+ key=['BK', 'NC', 'BT', 'K'],
24
+ use_cuda_graph=use_cuda_graph,
25
+ )
26
+ @triton.jit(do_not_specialize=['T'])
27
+ def chunk_dplr_bwd_kernel_intra(
28
+ q,
29
+ k,
30
+ a,
31
+ b,
32
+ gi,
33
+ ge,
34
+ dAqk,
35
+ dAqb,
36
+ dAak,
37
+ dAab,
38
+ dq,
39
+ dk,
40
+ da,
41
+ db,
42
+ dqg,
43
+ dkg,
44
+ dag,
45
+ dbg,
46
+ dgk,
47
+ dgk_offset,
48
+ offsets,
49
+ indices,
50
+ scale: tl.constexpr,
51
+ T,
52
+ H: tl.constexpr,
53
+ K: tl.constexpr,
54
+ BT: tl.constexpr,
55
+ BC: tl.constexpr,
56
+ BK: tl.constexpr,
57
+ NC: tl.constexpr,
58
+ USE_OFFSETS: tl.constexpr,
59
+ HEAD_FIRST: tl.constexpr,
60
+ GATHER_SUPPORTED: tl.constexpr
61
+ ):
62
+ i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
63
+ i_b, i_h = i_bh // H, i_bh % H
64
+ i_t, i_i = i_c // NC, i_c % NC
65
+ if USE_OFFSETS:
66
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
67
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
68
+ else:
69
+ bos, eos = i_b * T, i_b * T + T
70
+ T = eos - bos
71
+ if i_t * BT + i_i * BC >= T:
72
+ return
73
+
74
+ # offset calculation
75
+ ge += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
76
+ gi += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
77
+ q += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
78
+ a += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
79
+ b += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
80
+ k += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
81
+ dq += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
82
+ dk += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
83
+ da += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
84
+ db += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
85
+ dqg += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
86
+ dag += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
87
+ dkg += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
88
+ dbg += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
89
+ dgk += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
90
+ dgk_offset += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K
91
+ dAqk += i_bh * T * BT if HEAD_FIRST else (bos*H + i_h) * BT
92
+ dAqb += i_bh * T * BT if HEAD_FIRST else (bos*H + i_h) * BT
93
+ dAak += i_bh * T * BT if HEAD_FIRST else (bos*H + i_h) * BT
94
+ dAab += i_bh * T * BT if HEAD_FIRST else (bos*H + i_h) * BT
95
+
96
+ stride_qk = K if HEAD_FIRST else H*K
97
+ stride_A = BT if HEAD_FIRST else H*BT
98
+
99
+ p_ge = tl.make_block_ptr(ge, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
100
+ p_gi = tl.make_block_ptr(gi, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
101
+ # [BC, BK]
102
+ b_ge = tl.load(p_ge, boundary_check=(0, 1))
103
+ b_gi = tl.load(p_gi, boundary_check=(0, 1))
104
+ b_dq = tl.zeros([BC, BK], dtype=tl.float32)
105
+ b_da = tl.zeros([BC, BK], dtype=tl.float32)
106
+ b_dk = tl.zeros([BC, BK], dtype=tl.float32)
107
+ b_db = tl.zeros([BC, BK], dtype=tl.float32)
108
+ # intra chunk gradient calculation
109
+ p_dAqk = tl.make_block_ptr(dAqk, (T, BT), (stride_A, 1), (i_t*BT + i_i*BC, i_i*BC), (BC, BC), (1, 0))
110
+ p_dAab = tl.make_block_ptr(dAab, (T, BT), (stride_A, 1), (i_t*BT + i_i*BC, i_i*BC), (BC, BC), (1, 0))
111
+ p_dAqb = tl.make_block_ptr(dAqb, (T, BT), (stride_A, 1), (i_t*BT + i_i*BC, i_i*BC), (BC, BC), (1, 0))
112
+ p_dAak = tl.make_block_ptr(dAak, (T, BT), (stride_A, 1), (i_t*BT + i_i*BC, i_i*BC), (BC, BC), (1, 0))
113
+ o_i = tl.arange(0, BC)
114
+ p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t*BT + i_i*BC, i_k*BK), (BC, BK), (1, 0))
115
+ p_b = tl.make_block_ptr(b, (T, K), (stride_qk, 1), (i_t*BT + i_i*BC, i_k*BK), (BC, BK), (1, 0))
116
+ p_a = tl.make_block_ptr(a, (T, K), (stride_qk, 1), (i_t*BT + i_i*BC, i_k*BK), (BC, BK), (1, 0))
117
+ p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t*BT + i_i*BC, i_k*BK), (BC, BK), (1, 0))
118
+ b_k = tl.load(p_k, boundary_check=(0, 1)).to(tl.float32)
119
+ b_b = tl.load(p_b, boundary_check=(0, 1)).to(tl.float32)
120
+ b_q = tl.load(p_q, boundary_check=(0, 1)).to(tl.float32)
121
+ b_a = tl.load(p_a, boundary_check=(0, 1)).to(tl.float32)
122
+ b_dAqk = tl.load(p_dAqk, boundary_check=(0, 1)).to(tl.float32)
123
+ b_dAab = tl.load(p_dAab, boundary_check=(0, 1)).to(tl.float32)
124
+ b_dAqb = tl.load(p_dAqb, boundary_check=(0, 1)).to(tl.float32)
125
+ b_dAak = tl.load(p_dAak, boundary_check=(0, 1)).to(tl.float32)
126
+
127
+ # inter chunk gradient calculation
128
+ o_k = i_k * BK + tl.arange(0, BK)
129
+ m_k = o_k < K
130
+ if i_i > 0:
131
+ p_gn = gi + (i_t * BT + i_i * BC - 1) * stride_qk + o_k
132
+ p_gn = tl.max_contiguous(tl.multiple_of(p_gn, BK), BK)
133
+ # [BK,]
134
+ b_gn = tl.load(p_gn, mask=m_k, other=0)
135
+ # [BK,]
136
+ for i_j in range(0, i_i):
137
+ p_kj = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
138
+ p_bj = tl.make_block_ptr(b, (T, K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
139
+ p_gkj = tl.make_block_ptr(gi, (T, K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
140
+ p_dAqikj = tl.make_block_ptr(dAqk, (T, BT), (stride_A, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
141
+ p_dAaibj = tl.make_block_ptr(dAab, (T, BT), (stride_A, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
142
+ p_dAqibj = tl.make_block_ptr(dAqb, (T, BT), (stride_A, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
143
+ p_dAaikj = tl.make_block_ptr(dAak, (T, BT), (stride_A, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
144
+ # [BC, BK]
145
+ b_kj = tl.load(p_kj, boundary_check=(0, 1))
146
+ b_bj = tl.load(p_bj, boundary_check=(0, 1))
147
+ b_gkj = tl.load(p_gkj, boundary_check=(0, 1))
148
+ tmp = exp(b_gn[None, :] - b_gkj)
149
+ b_kjg = b_kj * tmp
150
+ b_bjg = b_bj * tmp
151
+ # [BC, BC]
152
+ b_dAqikj = tl.load(p_dAqikj, boundary_check=(0, 1))
153
+ b_dAaibj = tl.load(p_dAaibj, boundary_check=(0, 1))
154
+ b_dAqibj = tl.load(p_dAqibj, boundary_check=(0, 1))
155
+ b_dAaikj = tl.load(p_dAaikj, boundary_check=(0, 1))
156
+ # [BC, BK]
157
+ b_dq += tl.dot(b_dAqikj, b_kjg)
158
+ b_dq += tl.dot(b_dAqibj, b_bjg)
159
+ # [BC, BC]
160
+ b_da += tl.dot(b_dAaibj, b_bjg)
161
+ b_da += tl.dot(b_dAaikj, b_kjg)
162
+ b_dq *= exp(b_gi - b_gn[None, :])
163
+ b_da *= exp(b_ge - b_gn[None, :])
164
+
165
+ NC = min(NC, tl.cdiv(T - i_t * BT, BC))
166
+ if i_i < NC - 1:
167
+ p_gn = gi + (min(i_t * BT + i_i * BC + BC, T) - 1)*stride_qk + o_k
168
+ p_gn = tl.max_contiguous(tl.multiple_of(p_gn, BK), BK)
169
+ # [BK,]
170
+ b_gn = tl.load(p_gn, mask=m_k, other=0)
171
+ for i_j in range(i_i + 1, NC):
172
+ m_j = (i_t * BT + i_j * BC + tl.arange(0, BC)) < T
173
+ p_qj = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
174
+ p_aj = tl.make_block_ptr(a, (T, K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
175
+ p_gij = tl.make_block_ptr(gi, (T, K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
176
+ p_gej = tl.make_block_ptr(ge, (T, K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0))
177
+ p_dAqjki = tl.make_block_ptr(dAqk, (BT, T), (1, stride_A), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1))
178
+ p_dAajbi = tl.make_block_ptr(dAab, (BT, T), (1, stride_A), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1))
179
+ p_dAqjbi = tl.make_block_ptr(dAqb, (BT, T), (1, stride_A), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1))
180
+ p_dAajki = tl.make_block_ptr(dAak, (BT, T), (1, stride_A), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1))
181
+ b_qj = tl.load(p_qj, boundary_check=(0, 1))
182
+ b_aj = tl.load(p_aj, boundary_check=(0, 1))
183
+ b_gij = tl.load(p_gij, boundary_check=(0, 1))
184
+ b_gej = tl.load(p_gej, boundary_check=(0, 1))
185
+ b_gij = tl.where(m_j[:, None] & m_k, b_gij, float('-inf'))
186
+ b_gej = tl.where(m_j[:, None] & m_k, b_gej, float('-inf'))
187
+ b_qjg = b_qj * exp(b_gij - b_gn[None, :])
188
+ b_ajg = b_aj * exp(b_gej - b_gn[None, :])
189
+ # [BC, BC]
190
+ b_dAqjki = tl.load(p_dAqjki, boundary_check=(0, 1))
191
+ b_dAajbi = tl.load(p_dAajbi, boundary_check=(0, 1))
192
+ b_dAqjbi = tl.load(p_dAqjbi, boundary_check=(0, 1))
193
+ b_dAajki = tl.load(p_dAajki, boundary_check=(0, 1))
194
+ b_dk += tl.dot(b_dAqjki, b_qjg)
195
+ b_dk += tl.dot(b_dAajki, b_ajg)
196
+ b_db += tl.dot(b_dAqjbi, b_qjg)
197
+ b_db += tl.dot(b_dAajbi, b_ajg)
198
+ tmp = exp(b_gn[None, :] - b_gi)
199
+ b_dk *= tmp
200
+ b_db *= tmp
201
+
202
+ # intra chunk gradient calculation
203
+ for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
204
+ # trick to index the block
205
+ if GATHER_SUPPORTED:
206
+ row_idx = tl.full([1, BK], j, dtype=tl.int16)
207
+ col_idx = tl.full([BC, 1], j, dtype=tl.int16)
208
+ row_idx_bc = tl.full([1, BC], j, dtype=tl.int16)
209
+ # [1, BK]
210
+ b_kj = gather(b_k, row_idx, axis=0)
211
+ b_bj = gather(b_b, row_idx, axis=0)
212
+ b_gij = gather(b_gi, row_idx, axis=0)
213
+ b_gej = gather(b_ge, row_idx, axis=0)
214
+ b_qj = gather(b_q, row_idx, axis=0)
215
+ b_aj = gather(b_a, row_idx, axis=0)
216
+ # [BC, 1]
217
+ b_dAqk_j = gather(b_dAqk, col_idx, axis=1)
218
+ b_dAab_j = gather(b_dAab, col_idx, axis=1)
219
+ b_dAqb_j = gather(b_dAqb, col_idx, axis=1)
220
+ b_dAak_j = gather(b_dAak, col_idx, axis=1)
221
+ # [1, BC] -> [BC, 1]
222
+ b_dA_qk_j = tl.sum(gather(b_dAqk, row_idx_bc, axis=0), 0)[:, None]
223
+ b_dA_qk_j = tl.sum(gather(b_dAqk, row_idx_bc, axis=0), 0)[:, None]
224
+ b_dA_ab_j = tl.sum(gather(b_dAab, row_idx_bc, axis=0), 0)[:, None]
225
+ b_dA_qb_j = tl.sum(gather(b_dAqb, row_idx_bc, axis=0), 0)[:, None]
226
+ b_dA_ak_j = tl.sum(gather(b_dAak, row_idx_bc, axis=0), 0)[:, None]
227
+ else:
228
+ mask_idx = tl.arange(0, BC) == j
229
+ b_kj = tl.sum(tl.where(mask_idx[:, None], b_k, 0), 0)[None, :]
230
+ b_bj = tl.sum(tl.where(mask_idx[:, None], b_b, 0), 0)[None, :]
231
+ b_gij = tl.sum(tl.where(mask_idx[:, None], b_gi, 0), 0)[None, :]
232
+ b_gej = tl.sum(tl.where(mask_idx[:, None], b_ge, 0), 0)[None, :]
233
+ b_dAqk_j = tl.sum(tl.where(mask_idx[None, :], b_dAqk, 0), 1)[:, None]
234
+ b_dAab_j = tl.sum(tl.where(mask_idx[None, :], b_dAab, 0), 1)[:, None]
235
+ b_dAqb_j = tl.sum(tl.where(mask_idx[None, :], b_dAqb, 0), 1)[:, None]
236
+ b_dAak_j = tl.sum(tl.where(mask_idx[None, :], b_dAak, 0), 1)[:, None]
237
+ b_dA_qk_j = tl.sum(tl.where(mask_idx[:, None], b_dAqk, 0), 0)[:, None]
238
+ b_dA_ab_j = tl.sum(tl.where(mask_idx[:, None], b_dAab, 0), 0)[:, None]
239
+ b_dA_qb_j = tl.sum(tl.where(mask_idx[:, None], b_dAqb, 0), 0)[:, None]
240
+ b_dA_ak_j = tl.sum(tl.where(mask_idx[:, None], b_dAak, 0), 0)[:, None]
241
+ # [1, BK] b_qj, b_aj
242
+ b_qj = tl.sum(tl.where(mask_idx[:, None], b_q, 0), 0)[None, :]
243
+ b_aj = tl.sum(tl.where(mask_idx[:, None], b_a, 0), 0)[None, :]
244
+ # tl.static_print(b_kj)
245
+ m_e = o_i[:, None] > j
246
+ m_i = o_i[:, None] >= j
247
+ tmp1 = exp(b_gi - b_gij)
248
+ tmp2 = exp(b_ge - b_gij)
249
+ b_dq += tl.where(m_i, b_dAqk_j * b_kj * tmp1, 0.)
250
+ b_dq += tl.where(m_i, b_dAqb_j * b_bj * tmp1, 0.)
251
+ b_da += tl.where(m_e, b_dAab_j * b_bj * tmp2, 0.)
252
+ b_da += tl.where(m_e, b_dAak_j * b_kj * tmp2, 0.)
253
+
254
+ m_i = o_i[:, None] <= j
255
+ m_e = o_i[:, None] < j
256
+ tmp1 = exp(b_gij - b_gi)
257
+ tmp2 = exp(b_gej - b_gi)
258
+ b_dk += tl.where(m_i, b_dA_qk_j * b_qj * tmp1, 0.)
259
+ b_dk += tl.where(m_e, b_dA_ak_j * b_aj * tmp2, 0.)
260
+ b_db += tl.where(m_i, b_dA_qb_j * b_qj * tmp1, 0.)
261
+ b_db += tl.where(m_e, b_dA_ab_j * b_aj * tmp2, 0.)
262
+ # post processing
263
+ p_dq = tl.make_block_ptr(dq, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
264
+ p_dk = tl.make_block_ptr(dk, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
265
+ p_da = tl.make_block_ptr(da, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
266
+ p_db = tl.make_block_ptr(db, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
267
+ p_dgk = tl.make_block_ptr(dgk, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
268
+ p_dgk_offset = tl.make_block_ptr(dgk_offset, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
269
+ p_dqg = tl.make_block_ptr(dqg, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
270
+ p_dkg = tl.make_block_ptr(dkg, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
271
+ p_dag = tl.make_block_ptr(dag, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
272
+ p_dbg = tl.make_block_ptr(dbg, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
273
+ p_gn = gi + (min(i_t * BT + BT, T) - 1)*stride_qk + o_k
274
+ p_gn = tl.max_contiguous(tl.multiple_of(p_gn, BK), BK)
275
+ b_gn = tl.load(p_gn, mask=m_k, other=0)
276
+ b_da += tl.load(p_dag, boundary_check=(0, 1)) * exp(b_ge)
277
+ b_dq += tl.load(p_dqg, boundary_check=(0, 1)) * exp(b_gi) * scale
278
+ tmp = exp(b_gn[None, :] - b_gi)
279
+ b_dk += tl.load(p_dkg, boundary_check=(0, 1)) * tmp
280
+ b_db += tl.load(p_dbg, boundary_check=(0, 1)) * tmp
281
+ tl.store(p_dq, (b_dq).to(p_dq.dtype.element_ty), boundary_check=(0, 1))
282
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
283
+ tl.store(p_da, b_da.to(p_da.dtype.element_ty), boundary_check=(0, 1))
284
+ tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0, 1))
285
+ b_dgk = b_dq * b_q + b_da * b_a - b_dk * b_k - b_db * b_b
286
+ b_dgk_offset = b_da * b_a
287
+ tl.store(p_dgk, b_dgk.to(p_dgk.dtype.element_ty), boundary_check=(0, 1))
288
+ tl.store(p_dgk_offset, b_dgk_offset.to(p_dgk_offset.dtype.element_ty), boundary_check=(0, 1))
289
+
290
+
291
+ @triton.heuristics({
292
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
293
+ })
294
+ @triton.autotune(
295
+ configs=[
296
+ triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
297
+ for num_warps in [2, 4, 8, 16, 32]
298
+ for num_stages in [2, 3, 4]
299
+ for BK in [32, 64]
300
+ ],
301
+ key=['BK', 'BT', 'K'],
302
+ use_cuda_graph=use_cuda_graph,
303
+ )
304
+ @triton.jit(do_not_specialize=['T'])
305
+ def chunk_dplr_bwd_dgk_kernel(
306
+ dgk,
307
+ dgk_offset,
308
+ dgk_last,
309
+ dgk_output,
310
+ offsets,
311
+ indices,
312
+ T,
313
+ H: tl.constexpr,
314
+ K: tl.constexpr,
315
+ BT: tl.constexpr,
316
+ BK: tl.constexpr,
317
+ USE_OFFSETS: tl.constexpr,
318
+ HEAD_FIRST: tl.constexpr,
319
+ ):
320
+ i_t, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
321
+ i_b, i_h = i_bh // H, i_bh % H
322
+ if USE_OFFSETS:
323
+ i_tg = i_t
324
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
325
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
326
+ T = eos - bos
327
+ NT = tl.cdiv(T, BT)
328
+ else:
329
+ NT = tl.cdiv(T, BT)
330
+ i_tg = i_b * NT + i_t
331
+ bos, eos = i_b * T, i_b * T + T
332
+ T = eos - bos
333
+ stride_qk = K if HEAD_FIRST else H * K
334
+ dgk += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
335
+ dgk_offset += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
336
+ dgk_last += ((i_bh * NT + i_t) * K) if HEAD_FIRST else (i_tg * H + i_h) * K
337
+ dgk_output += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K
338
+ p_dgk_last = dgk_last + tl.arange(0, BK) + i_k * BK
339
+ m_k = tl.arange(0, BK) + i_k * BK < K
340
+ b_dgk_last = tl.load(p_dgk_last, mask=m_k, other=0)
341
+ p_dgk_offset = tl.make_block_ptr(dgk_offset, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
342
+ p_dgk = tl.make_block_ptr(dgk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
343
+ b_dgk = tl.load(p_dgk, boundary_check=(0, 1))
344
+ b_dgk_offset = tl.load(p_dgk_offset, boundary_check=(0, 1))
345
+ # m_inv_cumsum = (tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]).to(tl.float32)
346
+ # b_dgk_cumsum = tl.dot(m_inv_cumsum, b_dgk, allow_tf32=False)
347
+ b_dgk_cumsum = tl.cumsum(b_dgk, 0, reverse=True)
348
+ b_dgk_cumsum += b_dgk_last[None, :]
349
+ b_dgk_cumsum -= b_dgk_offset
350
+ p_dgk_output = tl.make_block_ptr(dgk_output, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
351
+ tl.store(p_dgk_output, b_dgk_cumsum.to(p_dgk_output.dtype.element_ty), boundary_check=(0, 1))
352
+
353
+
354
+ def chunk_dplr_bwd_dqk_intra(
355
+ q: torch.Tensor,
356
+ k: torch.Tensor,
357
+ a: torch.Tensor,
358
+ b: torch.Tensor,
359
+ gi: torch.Tensor,
360
+ ge: torch.Tensor,
361
+ dAqk: torch.Tensor,
362
+ dAqb: torch.Tensor,
363
+ dAak: torch.Tensor,
364
+ dAab: torch.Tensor,
365
+ dqg: torch.Tensor,
366
+ dkg: torch.Tensor,
367
+ dag: torch.Tensor,
368
+ dbg: torch.Tensor,
369
+ dgk_last: torch.Tensor,
370
+ offsets: Optional[torch.LongTensor] = None,
371
+ indices: Optional[torch.LongTensor] = None,
372
+ head_first: bool = True,
373
+ scale: float = 1.0,
374
+ chunk_size: int = 64,
375
+ ):
376
+ if head_first:
377
+ B, H, T, K = q.shape
378
+ else:
379
+ B, T, H, K = q.shape
380
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
381
+ BC = min(16, BT)
382
+ BK = min(64, triton.next_power_of_2(K)) if check_shared_mem() else min(32, triton.next_power_of_2(K))
383
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
384
+ NC = triton.cdiv(BT, BC)
385
+ NK = triton.cdiv(K, BK)
386
+
387
+ dq = torch.empty_like(q)
388
+ dk = torch.empty_like(k)
389
+ da = torch.empty_like(a)
390
+ db = torch.empty_like(b)
391
+ dgk = torch.empty_like(gi, dtype=torch.float)
392
+ dgk_offset = torch.empty_like(gi, dtype=torch.float)
393
+
394
+ grid = (NK, NT * NC, B * H)
395
+ chunk_dplr_bwd_kernel_intra[grid](
396
+ q=q,
397
+ k=k,
398
+ a=a,
399
+ b=b,
400
+ gi=gi,
401
+ ge=ge,
402
+ dAqk=dAqk,
403
+ dAqb=dAqb,
404
+ dAak=dAak,
405
+ dAab=dAab,
406
+ dq=dq,
407
+ dk=dk,
408
+ dgk=dgk,
409
+ dgk_offset=dgk_offset,
410
+ dqg=dqg,
411
+ dkg=dkg,
412
+ dag=dag,
413
+ dbg=dbg,
414
+ da=da,
415
+ db=db,
416
+ offsets=offsets,
417
+ indices=indices,
418
+ scale=scale,
419
+ T=T,
420
+ H=H,
421
+ K=K,
422
+ BT=BT,
423
+ BC=BC,
424
+ BK=BK,
425
+ NC=NC,
426
+ HEAD_FIRST=head_first,
427
+ GATHER_SUPPORTED=is_gather_supported
428
+ )
429
+
430
+ def grid2(meta): return (NT, triton.cdiv(K, meta['BK']), B * H)
431
+ dgk_output = torch.empty_like(dgk)
432
+
433
+ chunk_dplr_bwd_dgk_kernel[grid2](
434
+ dgk=dgk,
435
+ dgk_offset=dgk_offset,
436
+ dgk_last=dgk_last,
437
+ dgk_output=dgk_output,
438
+ offsets=offsets,
439
+ indices=indices,
440
+ T=T,
441
+ H=H,
442
+ K=K,
443
+ BT=BT,
444
+ HEAD_FIRST=head_first
445
+ )
446
+ return dq, dk, da, db, dgk_output
fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import exp, gather
11
+ from fla.utils import is_gather_supported, use_cuda_graph
12
+
13
+
14
+ @triton.heuristics({
15
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
16
+ })
17
+ @triton.autotune(
18
+ configs=[
19
+ triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
20
+ for BK in [32, 64]
21
+ for num_warps in [2, 4, 8, 16]
22
+ for num_stages in [2, 3, 4]
23
+ ],
24
+ key=['BC', 'K'],
25
+ use_cuda_graph=use_cuda_graph,
26
+ )
27
+ @triton.jit(do_not_specialize=['T'])
28
+ def chunk_dplr_fwd_A_kernel_intra_sub_inter(
29
+ q,
30
+ k,
31
+ a,
32
+ b,
33
+ gi, # cumsum
34
+ ge, # before cumsum
35
+ Aqk,
36
+ Aqb,
37
+ Aab,
38
+ Aak,
39
+ offsets,
40
+ indices,
41
+ scale: tl.constexpr,
42
+ T,
43
+ H: tl.constexpr,
44
+ K: tl.constexpr,
45
+ BT: tl.constexpr,
46
+ BC: tl.constexpr,
47
+ BK: tl.constexpr,
48
+ NC: tl.constexpr,
49
+ USE_OFFSETS: tl.constexpr,
50
+ HEAD_FIRST: tl.constexpr,
51
+ ):
52
+ i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
53
+ i_b, i_h = i_bh // H, i_bh % H
54
+ i_i, i_j = i_c // NC, i_c % NC
55
+ if USE_OFFSETS:
56
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
57
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
58
+ T = eos - bos
59
+ else:
60
+ bos, eos = i_b * T, i_b * T + T
61
+
62
+ if i_t * BT + i_i * BC >= T:
63
+ return
64
+ if i_i <= i_j:
65
+ return
66
+
67
+ b_Aqk = tl.zeros([BC, BC], dtype=tl.float32)
68
+ b_Aqb = tl.zeros([BC, BC], dtype=tl.float32)
69
+ b_Aab = tl.zeros([BC, BC], dtype=tl.float32)
70
+ b_Aak = tl.zeros([BC, BC], dtype=tl.float32)
71
+ for i_k in range(tl.cdiv(K, BK)):
72
+ o_k = i_k * BK + tl.arange(0, BK)
73
+ m_k = o_k < K
74
+
75
+ if HEAD_FIRST:
76
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
77
+ p_a = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
78
+ p_gq_i = tl.make_block_ptr(gi + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
79
+ p_gq_e = tl.make_block_ptr(ge + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
80
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
81
+ p_b = tl.make_block_ptr(b + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
82
+ p_gk = tl.make_block_ptr(gi + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
83
+ p_gn = tl.max_contiguous(tl.multiple_of(gi + (i_bh * T + i_t * BT + i_i * BC - 1) * K + o_k, BK), BK)
84
+ else:
85
+ p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
86
+ p_a = tl.make_block_ptr(a + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
87
+ p_gq_i = tl.make_block_ptr(gi + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
88
+ p_gq_e = tl.make_block_ptr(ge + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
89
+ p_k = tl.make_block_ptr(k + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
90
+ p_b = tl.make_block_ptr(b + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
91
+ p_gk = tl.make_block_ptr(gi + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
92
+ p_gn = gi + (bos + i_t * BT + i_i * BC - 1) * H*K + i_h * K + o_k
93
+ # [BK,]
94
+ b_gn = tl.load(p_gn, mask=m_k, other=0).to(tl.float32)
95
+ # [BC, BK]
96
+ b_q = tl.load(p_q, boundary_check=(0, 1))
97
+ b_a = tl.load(p_a, boundary_check=(0, 1))
98
+ b_gq_i = tl.load(p_gq_i, boundary_check=(0, 1))
99
+ b_gq_e = tl.load(p_gq_e, boundary_check=(0, 1))
100
+ b_ag = b_a * exp(b_gq_e - b_gn[None, :])
101
+ b_qg = b_q * exp(b_gq_i - b_gn[None, :]) * scale
102
+ # [BK, BC]
103
+ b_k = tl.load(p_k, boundary_check=(0, 1))
104
+ b_b = tl.load(p_b, boundary_check=(0, 1))
105
+ b_gk = tl.load(p_gk, boundary_check=(0, 1)).to(tl.float32)
106
+ tmp = exp(b_gn[:, None] - b_gk)
107
+ b_kg = b_k * tmp
108
+ b_bg = b_b * tmp
109
+ # [BC, BC] using tf32 to improve precision here.
110
+ b_Aab += tl.dot(b_ag, b_bg)
111
+ b_Aak += tl.dot(b_ag, b_kg)
112
+ b_Aqk += tl.dot(b_qg, b_kg)
113
+ b_Aqb += tl.dot(b_qg, b_bg)
114
+
115
+ if HEAD_FIRST:
116
+ p_Aqk = tl.make_block_ptr(Aqk + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
117
+ p_Aqb = tl.make_block_ptr(Aqb + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
118
+ p_Aab = tl.make_block_ptr(Aab + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
119
+ p_Aak = tl.make_block_ptr(Aak + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
120
+ else:
121
+ p_Aqk = tl.make_block_ptr(Aqk + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
122
+ p_Aqb = tl.make_block_ptr(Aqb + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
123
+ p_Aab = tl.make_block_ptr(Aab + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
124
+ p_Aak = tl.make_block_ptr(Aak + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
125
+ tl.store(p_Aqk, b_Aqk.to(Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
126
+ tl.store(p_Aqb, b_Aqb.to(Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
127
+ tl.store(p_Aab, b_Aab.to(Aab.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
128
+ tl.store(p_Aak, b_Aak.to(Aak.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
129
+
130
+
131
+ @triton.heuristics({
132
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
133
+ })
134
+ @triton.autotune(
135
+ configs=[
136
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
137
+ for num_warps in [2, 4, 8, 16, 32]
138
+ for num_stages in [2, 3, 4]
139
+ ],
140
+ key=['BK', 'BT'],
141
+ use_cuda_graph=use_cuda_graph,
142
+ )
143
+ @triton.jit(do_not_specialize=['T'])
144
+ def chunk_dplr_fwd_A_kernel_intra_sub_intra(
145
+ q,
146
+ k,
147
+ a,
148
+ b,
149
+ gi,
150
+ ge,
151
+ qg,
152
+ kg,
153
+ ag,
154
+ bg,
155
+ Aqk,
156
+ Aqb,
157
+ Aab,
158
+ Aak,
159
+ offsets,
160
+ indices,
161
+ scale: tl.constexpr,
162
+ T,
163
+ H: tl.constexpr,
164
+ K: tl.constexpr,
165
+ BT: tl.constexpr,
166
+ BC: tl.constexpr,
167
+ BK: tl.constexpr,
168
+ NC: tl.constexpr,
169
+ USE_OFFSETS: tl.constexpr,
170
+ HEAD_FIRST: tl.constexpr,
171
+ GATHER_SUPPORTED: tl.constexpr
172
+ ):
173
+ i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
174
+ i_b, i_h = i_bh // H, i_bh % H
175
+ i_j = i_i
176
+ if USE_OFFSETS:
177
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
178
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
179
+ T = eos - bos
180
+ else:
181
+ bos, eos = i_b * T, i_b * T + T
182
+
183
+ if i_t * BT + i_i * BC >= T:
184
+ return
185
+
186
+ o_i = tl.arange(0, BC)
187
+ o_k = tl.arange(0, BK)
188
+ m_k = o_k < K
189
+ m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
190
+ last_idx = min((i_t+1) * BT, T) - 1
191
+ if HEAD_FIRST:
192
+ o_A = i_bh * T*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC
193
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
194
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
195
+ p_a = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
196
+ p_b = tl.make_block_ptr(b + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
197
+ p_gi = tl.make_block_ptr(gi + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
198
+ p_ge = tl.make_block_ptr(ge + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
199
+ p_g_last = gi + i_bh * T*K + last_idx * K + tl.arange(0, BK)
200
+ b_g_last = tl.load(p_g_last, mask=m_k, other=0)
201
+
202
+ p_qg = tl.make_block_ptr(qg + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
203
+ p_kg = tl.make_block_ptr(kg + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
204
+ p_ag = tl.make_block_ptr(ag + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
205
+ p_bg = tl.make_block_ptr(bg + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
206
+ else:
207
+ o_A = (bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_h * BT + i_j * BC
208
+ p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
209
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
210
+ p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
211
+ p_b = tl.make_block_ptr(b + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
212
+ p_gi = tl.make_block_ptr(gi + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
213
+ p_ge = tl.make_block_ptr(ge + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
214
+ p_g_last = gi + (bos * H + i_h) * K + last_idx * H * K + tl.arange(0, BK)
215
+ b_g_last = tl.load(p_g_last, mask=m_k, other=0)
216
+ p_qg = tl.make_block_ptr(qg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
217
+ p_kg = tl.make_block_ptr(kg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
218
+ p_ag = tl.make_block_ptr(ag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
219
+ p_bg = tl.make_block_ptr(bg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
220
+
221
+ b_q = tl.load(p_q, boundary_check=(0, 1))
222
+ b_q = b_q * scale
223
+ b_k = tl.load(p_k, boundary_check=(0, 1))
224
+ b_a = tl.load(p_a, boundary_check=(0, 1))
225
+ b_b = tl.load(p_b, boundary_check=(0, 1))
226
+ b_gi = tl.load(p_gi, boundary_check=(0, 1)).to(tl.float32)
227
+ b_ge = tl.load(p_ge, boundary_check=(0, 1)).to(tl.float32)
228
+
229
+ # deal with decay term.
230
+ g_exp = exp(b_gi)
231
+ g_exp_inv = exp(-b_gi + b_g_last[None, :])
232
+ b_qg = b_q * g_exp
233
+ b_kg = b_k * g_exp_inv
234
+ b_bg = b_b * g_exp_inv
235
+ b_ag = b_a * exp(b_ge)
236
+ tl.store(p_qg, b_qg.to(p_qg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
237
+ tl.store(p_bg, b_bg.to(p_bg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
238
+ tl.store(p_ag, b_ag.to(p_ag.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
239
+ tl.store(p_kg, b_kg.to(p_kg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
240
+ # tl.debug_barrier()
241
+
242
+ b_q = b_q.to(b_k.dtype)
243
+ # inner attn
244
+ for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
245
+ # a trick to index the j-th row of b_k, b_g, b_b
246
+ if GATHER_SUPPORTED:
247
+ row_idx = tl.full([1, BK], j, dtype=tl.int16)
248
+ # [1, BK]
249
+ b_k_j = gather(b_k, row_idx, axis=0)
250
+ b_gk_j = gather(b_gi, row_idx, axis=0)
251
+ b_b_j = gather(b_b, row_idx, axis=0)
252
+ else:
253
+ mask = tl.arange(0, BC) == j
254
+ b_k_j = tl.sum(tl.where(mask[:, None], b_k, 0), 0)[None, :]
255
+ b_gk_j = tl.sum(tl.where(mask[:, None], b_gi, 0), 0)[None, :]
256
+ b_b_j = tl.sum(tl.where(mask[:, None], b_b, 0), 0)[None, :]
257
+ mask = tl.arange(0, BC) == j
258
+ tmp = exp(b_gi - b_gk_j)
259
+ b_A_qk = tl.sum(b_q * b_k_j * tmp, 1)
260
+ b_A_qk = tl.where(o_i >= j, b_A_qk, 0.)
261
+ b_A_qb = tl.sum(b_q * b_b_j * tmp, 1)
262
+ b_A_qb = tl.where(o_i >= j, b_A_qb, 0.)
263
+ tmp2 = exp(b_ge - b_gk_j)
264
+ b_A_ak = tl.sum(b_a * b_k_j * tmp2, 1)
265
+ b_A_ak = tl.where(o_i > j, b_A_ak, 0.)
266
+ b_A_ab = tl.sum(b_a * b_b_j * tmp2, 1)
267
+ b_A_ab = tl.where(o_i > j, b_A_ab, 0.)
268
+ tl.store(Aqk + o_A + j, b_A_qk.to(dtype=Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
269
+ tl.store(Aqb + o_A + j, b_A_qb.to(dtype=Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
270
+ tl.store(Aab + o_A + j, b_A_ab.to(dtype=Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
271
+ tl.store(Aak + o_A + j, b_A_ak.to(dtype=Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
272
+
273
+
274
+ def chunk_fwd_intra_dplr_fn(
275
+ q: torch.Tensor,
276
+ k: torch.Tensor,
277
+ a: torch.Tensor,
278
+ b: torch.Tensor,
279
+ gi: torch.Tensor,
280
+ ge: torch.Tensor,
281
+ scale: float,
282
+ chunk_size: int,
283
+ offsets: Optional[torch.LongTensor] = None,
284
+ indices: Optional[torch.LongTensor] = None,
285
+ head_first: bool = True,
286
+ ):
287
+ if head_first:
288
+ B, H, T, K = k.shape
289
+ else:
290
+ B, T, H, K = k.shape
291
+ BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
292
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
293
+ BC = min(16, BT)
294
+ NC = triton.cdiv(BT, BC)
295
+
296
+ Aqk = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=q.dtype)
297
+ Aqb = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=q.dtype)
298
+ # involving matrix inverse and it'd be better to use float here.
299
+ Aab = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=torch.float)
300
+ Aak = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=torch.float)
301
+ grid = (NT, NC * NC, B * H)
302
+
303
+ chunk_dplr_fwd_A_kernel_intra_sub_inter[grid](
304
+ q=q, k=k, a=a, b=b, gi=gi, ge=ge, Aqk=Aqk, Aqb=Aqb, Aab=Aab, Aak=Aak,
305
+ offsets=offsets, indices=indices,
306
+ scale=scale,
307
+ T=T, H=H, K=K, BT=BT, BC=BC, NC=NC,
308
+ HEAD_FIRST=head_first
309
+ )
310
+ grid = (NT, NC, B * H)
311
+ BK = triton.next_power_of_2(K)
312
+ qg = torch.empty_like(q)
313
+ kg = torch.empty_like(k, dtype=q.dtype)
314
+ ag = torch.empty_like(a, dtype=q.dtype)
315
+ bg = torch.empty_like(b, dtype=q.dtype)
316
+ chunk_dplr_fwd_A_kernel_intra_sub_intra[grid](
317
+ q=q, k=k, a=a, b=b, gi=gi, ge=ge, Aqk=Aqk, Aqb=Aqb, Aab=Aab, Aak=Aak,
318
+ qg=qg, kg=kg, ag=ag, bg=bg,
319
+ offsets=offsets, indices=indices,
320
+ scale=scale,
321
+ T=T, H=H, K=K, BT=BT, BC=BC, BK=BK, HEAD_FIRST=head_first, NC=NC,
322
+ GATHER_SUPPORTED=is_gather_supported
323
+ )
324
+ return Aab, Aqk, Aak, Aqb, qg, kg, ag, bg
fla/ops/generalized_delta_rule/dplr/chunk_h_bwd.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.utils import prepare_chunk_offsets
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import check_shared_mem, use_cuda_graph
13
+
14
+
15
+ @triton.heuristics({
16
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
17
+ 'USE_INITIAL_STATE': lambda args: args['dh0'] is not None,
18
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
19
+ })
20
+ @triton.autotune(
21
+ configs=[
22
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
23
+ for num_warps in [2, 4, 8, 16, 32]
24
+ for num_stages in [2, 3, 4]
25
+ ],
26
+ key=['BT', 'BK', 'BV', "V"],
27
+ use_cuda_graph=use_cuda_graph,
28
+ )
29
+ @triton.jit(do_not_specialize=['T'])
30
+ def chunk_dplr_bwd_kernel_dhu(
31
+ qg,
32
+ bg,
33
+ w,
34
+ gk,
35
+ dht,
36
+ dh0,
37
+ do,
38
+ dh,
39
+ dv,
40
+ dv2,
41
+ offsets,
42
+ chunk_offsets,
43
+ T,
44
+ H: tl.constexpr,
45
+ K: tl.constexpr,
46
+ V: tl.constexpr,
47
+ BT: tl.constexpr,
48
+ BC: tl.constexpr,
49
+ BK: tl.constexpr,
50
+ BV: tl.constexpr,
51
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
52
+ USE_INITIAL_STATE: tl.constexpr,
53
+ USE_OFFSETS: tl.constexpr,
54
+ HEAD_FIRST: tl.constexpr
55
+ ):
56
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
57
+ i_n, i_h = i_nh // H, i_nh % H
58
+ if USE_OFFSETS:
59
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
60
+ T = eos - bos
61
+ NT = tl.cdiv(T, BT)
62
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
63
+ else:
64
+ bos, eos = i_n * T, i_n * T + T
65
+ NT = tl.cdiv(T, BT)
66
+ boh = i_n * NT
67
+
68
+ # [BK, BV]
69
+ b_dh = tl.zeros([BK, BV], dtype=tl.float32)
70
+ if USE_FINAL_STATE_GRADIENT:
71
+ p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
72
+ b_dh += tl.load(p_dht, boundary_check=(0, 1))
73
+
74
+ mask_k = tl.arange(0, BK) < K
75
+ for i_t in range(NT - 1, -1, -1):
76
+ if HEAD_FIRST:
77
+ p_dh = tl.make_block_ptr(dh + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
78
+ else:
79
+ p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
80
+ tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
81
+ b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32)
82
+ for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1):
83
+ if HEAD_FIRST:
84
+ p_qg = tl.make_block_ptr(qg + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
85
+ p_bg = tl.make_block_ptr(bg + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
86
+ p_w = tl.make_block_ptr(w + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
87
+ p_dv = tl.make_block_ptr(dv + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
88
+ p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
89
+ p_dv2 = tl.make_block_ptr(dv2 + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
90
+ else:
91
+ p_qg = tl.make_block_ptr(qg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
92
+ p_bg = tl.make_block_ptr(bg+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
93
+ p_w = tl.make_block_ptr(w+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
94
+ p_dv = tl.make_block_ptr(dv+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
95
+ p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
96
+ p_dv2 = tl.make_block_ptr(dv2+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
97
+ # [BK, BT]
98
+ b_qg = tl.load(p_qg, boundary_check=(0, 1))
99
+ # [BT, BK]
100
+ b_bg = tl.load(p_bg, boundary_check=(0, 1))
101
+ b_w = tl.load(p_w, boundary_check=(0, 1))
102
+ # [BT, V]
103
+ b_do = tl.load(p_do, boundary_check=(0, 1))
104
+ b_dv = tl.load(p_dv, boundary_check=(0, 1))
105
+ b_dv2 = b_dv + tl.dot(b_bg, b_dh.to(b_bg.dtype))
106
+ tl.store(p_dv2, b_dv2.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
107
+ # [BK, BV]
108
+ b_dh_tmp += tl.dot(b_qg, b_do.to(b_qg.dtype))
109
+ b_dh_tmp += tl.dot(b_w, b_dv2.to(b_qg.dtype))
110
+ last_idx = min((i_t + 1) * BT, T) - 1
111
+ if HEAD_FIRST:
112
+ bg_last = tl.load(gk + (i_nh * T + last_idx) * K + tl.arange(0, BK), mask=mask_k)
113
+ else:
114
+ bg_last = tl.load(gk + ((bos + last_idx) * H + i_h) * K + tl.arange(0, BK), mask=mask_k)
115
+ b_dh *= exp(bg_last)[:, None]
116
+ b_dh += b_dh_tmp
117
+
118
+ if USE_INITIAL_STATE:
119
+ p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
120
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1))
121
+
122
+
123
+ def chunk_dplr_bwd_dhu(
124
+ qg: torch.Tensor,
125
+ bg: torch.Tensor,
126
+ w: torch.Tensor,
127
+ gk: torch.Tensor,
128
+ h0: torch.Tensor,
129
+ dht: Optional[torch.Tensor],
130
+ do: torch.Tensor,
131
+ dv: torch.Tensor,
132
+ offsets: Optional[torch.LongTensor] = None,
133
+ indices: Optional[torch.LongTensor] = None,
134
+ head_first: bool = True,
135
+ chunk_size: int = 64
136
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
137
+ if head_first:
138
+ B, H, T, K, V = *qg.shape, do.shape[-1]
139
+ else:
140
+ B, T, H, K, V = *qg.shape, do.shape[-1]
141
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
142
+ BK = triton.next_power_of_2(K)
143
+ assert BK <= 256, "current kernel does not support head dimension being larger than 256."
144
+ # H100
145
+ if check_shared_mem('hopper', qg.device.index):
146
+ BV = 64
147
+ BC = 64 if K <= 128 else 32
148
+ elif check_shared_mem('ampere', qg.device.index): # A100
149
+ BV = 32
150
+ BC = 32
151
+ else: # Etc: 4090
152
+ BV = 16
153
+ BC = 16
154
+
155
+ # N: the actual number of sequences in the batch with either equal or variable lengths
156
+ if offsets is None:
157
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
158
+ else:
159
+ N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)
160
+
161
+ BC = min(BT, BC)
162
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
163
+ assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
164
+
165
+ if head_first:
166
+ dh = qg.new_empty(B, H, NT, K, V)
167
+ else:
168
+ dh = qg.new_empty(B, NT, H, K, V)
169
+ dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None
170
+ dv2 = torch.zeros_like(dv)
171
+
172
+ grid = (NK, NV, N * H)
173
+ chunk_dplr_bwd_kernel_dhu[grid](
174
+ qg=qg,
175
+ bg=bg,
176
+ w=w,
177
+ gk=gk,
178
+ dht=dht,
179
+ dh0=dh0,
180
+ do=do,
181
+ dh=dh,
182
+ dv=dv,
183
+ dv2=dv2,
184
+ offsets=offsets,
185
+ chunk_offsets=chunk_offsets,
186
+ T=T,
187
+ H=H,
188
+ K=K,
189
+ V=V,
190
+ BT=BT,
191
+ BC=BC,
192
+ BK=BK,
193
+ BV=BV,
194
+ HEAD_FIRST=head_first
195
+ )
196
+ return dh, dh0, dv2
fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.common.utils import prepare_chunk_offsets
11
+ from fla.ops.utils.op import exp
12
+ from fla.utils import check_shared_mem, use_cuda_graph
13
+
14
+
15
+ @triton.heuristics({
16
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
17
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
18
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None,
19
+ })
20
+ @triton.autotune(
21
+ configs=[
22
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
23
+ for num_warps in [2, 4, 8, 16, 32]
24
+ for num_stages in [2, 3, 4]
25
+ ],
26
+ key=['BT', 'BK', 'BV'],
27
+ use_cuda_graph=use_cuda_graph,
28
+ )
29
+ @triton.jit(do_not_specialize=['T'])
30
+ def chunk_dplr_fwd_kernel_h(
31
+ kg,
32
+ v,
33
+ w,
34
+ bg,
35
+ u,
36
+ v_new,
37
+ gk,
38
+ h,
39
+ h0,
40
+ ht,
41
+ offsets,
42
+ chunk_offsets,
43
+ T,
44
+ H: tl.constexpr,
45
+ K: tl.constexpr,
46
+ V: tl.constexpr,
47
+ BT: tl.constexpr,
48
+ BC: tl.constexpr,
49
+ BK: tl.constexpr,
50
+ BV: tl.constexpr,
51
+ NT: tl.constexpr,
52
+ USE_INITIAL_STATE: tl.constexpr,
53
+ STORE_FINAL_STATE: tl.constexpr,
54
+ USE_OFFSETS: tl.constexpr,
55
+ HEAD_FIRST: tl.constexpr,
56
+ ):
57
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
58
+ i_n, i_h = i_nh // H, i_nh % H
59
+ if USE_OFFSETS:
60
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
61
+ T = eos - bos
62
+ NT = tl.cdiv(T, BT)
63
+ boh = tl.load(chunk_offsets + i_n).to(tl.int32)
64
+ else:
65
+ bos, eos = i_n * T, i_n * T + T
66
+ NT = tl.cdiv(T, BT)
67
+ boh = i_n * NT
68
+
69
+ # [BK, BV]
70
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
71
+ if USE_INITIAL_STATE:
72
+ p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
73
+ b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)
74
+
75
+ for i_t in range(NT):
76
+ if HEAD_FIRST:
77
+ p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
78
+ else:
79
+ p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
80
+ tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
81
+
82
+ b_hc = tl.zeros([BK, BV], dtype=tl.float32)
83
+ # since we need to make all DK in the SRAM. we face serve SRAM memory burden. By subchunking we allievate such burden
84
+ for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)):
85
+ if HEAD_FIRST:
86
+ p_kg = tl.make_block_ptr(kg + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
87
+ p_bg = tl.make_block_ptr(bg + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
88
+ p_w = tl.make_block_ptr(w + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
89
+ p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
90
+ p_u = tl.make_block_ptr(u + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
91
+ p_v_new = tl.make_block_ptr(v_new+i_nh*T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
92
+ else:
93
+ p_kg = tl.make_block_ptr(kg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
94
+ p_bg = tl.make_block_ptr(bg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1))
95
+ p_w = tl.make_block_ptr(w+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0))
96
+ p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
97
+ p_u = tl.make_block_ptr(u+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0))
98
+ p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT+i_c*BC, i_v * BV), (BC, BV), (1, 0))
99
+ # [BK, BC]
100
+ b_kg = tl.load(p_kg, boundary_check=(0, 1))
101
+ b_v = tl.load(p_v, boundary_check=(0, 1))
102
+ b_w = tl.load(p_w, boundary_check=(0, 1))
103
+ b_bg = tl.load(p_bg, boundary_check=(0, 1))
104
+ b_v2 = tl.dot(b_w, b_h.to(b_w.dtype)) + tl.load(p_u, boundary_check=(0, 1))
105
+ b_hc += tl.dot(b_kg, b_v)
106
+ b_hc += tl.dot(b_bg.to(b_hc.dtype), b_v2)
107
+ tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
108
+
109
+ last_idx = min((i_t + 1) * BT, T) - 1
110
+ if HEAD_FIRST:
111
+ b_g_last = tl.load(gk + i_nh * T * K + last_idx * K + tl.arange(0, BK), mask=tl.arange(0, BK) < K).to(tl.float32)
112
+ else:
113
+ b_g_last = tl.load(gk + (bos + last_idx) * H * K + i_h * K +
114
+ tl.arange(0, BK), mask=tl.arange(0, BK) < K).to(tl.float32)
115
+ b_h *= exp(b_g_last[:, None])
116
+ b_h += b_hc
117
+
118
+ if STORE_FINAL_STATE:
119
+ p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
120
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
121
+
122
+
123
+ def chunk_dplr_fwd_h(
124
+ kg: torch.Tensor,
125
+ v: torch.Tensor,
126
+ w: torch.Tensor,
127
+ u: torch.Tensor,
128
+ bg: torch.Tensor,
129
+ gk: torch.Tensor,
130
+ initial_state: Optional[torch.Tensor] = None,
131
+ output_final_state: bool = False,
132
+ offsets: Optional[torch.LongTensor] = None,
133
+ indices: Optional[torch.LongTensor] = None,
134
+ head_first: bool = True,
135
+ chunk_size: int = 64
136
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
137
+ if head_first:
138
+ B, H, T, K, V = *kg.shape, u.shape[-1]
139
+ else:
140
+ B, T, H, K, V = *kg.shape, u.shape[-1]
141
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
142
+ # N: the actual number of sequences in the batch with either equal or variable lengths
143
+ if offsets is None:
144
+ N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
145
+ else:
146
+ N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT)
147
+ BK = triton.next_power_of_2(K)
148
+ assert BK <= 256, "current kernel does not support head dimension larger than 256."
149
+ # H100 can have larger block size
150
+
151
+ if check_shared_mem('hopper', kg.device.index):
152
+ BV = 64
153
+ BC = 64 if K <= 128 else 32
154
+ elif check_shared_mem('ampere', kg.device.index): # A100
155
+ BV = 32
156
+ BC = 32
157
+ else:
158
+ BV = 16
159
+ BC = 16
160
+
161
+ BC = min(BT, BC)
162
+ NK = triton.cdiv(K, BK)
163
+ NV = triton.cdiv(V, BV)
164
+ assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'
165
+
166
+ if head_first:
167
+ h = kg.new_empty(B, H, NT, K, V)
168
+ else:
169
+ h = kg.new_empty(B, NT, H, K, V)
170
+ final_state = kg.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
171
+ v_new = torch.empty_like(u)
172
+ grid = (NK, NV, N * H)
173
+ chunk_dplr_fwd_kernel_h[grid](
174
+ kg=kg,
175
+ v=v,
176
+ w=w,
177
+ bg=bg,
178
+ u=u,
179
+ v_new=v_new,
180
+ h=h,
181
+ gk=gk,
182
+ h0=initial_state,
183
+ ht=final_state,
184
+ offsets=offsets,
185
+ chunk_offsets=chunk_offsets,
186
+ T=T,
187
+ H=H,
188
+ K=K,
189
+ V=V,
190
+ BT=BT,
191
+ BC=BC,
192
+ BK=BK,
193
+ BV=BV,
194
+ NT=NT,
195
+ HEAD_FIRST=head_first
196
+ )
197
+ return h, v_new, final_state
fla/ops/generalized_delta_rule/dplr/fused_recurrent.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import exp
11
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard, use_cuda_graph
12
+
13
+
14
+ @triton.heuristics({
15
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
16
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
17
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
18
+ })
19
+ @triton.autotune(
20
+ configs=[
21
+ triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages)
22
+ for BV in [16, 32, 64]
23
+ for num_warps in [2, 4, 8, 16]
24
+ for num_stages in [2, 3, 4]
25
+ ],
26
+ key=['BK'],
27
+ use_cuda_graph=use_cuda_graph,
28
+ )
29
+ @triton.jit(do_not_specialize=['T'])
30
+ def fused_recurrent_dplr_delta_rule_fwd_kernel(
31
+ q,
32
+ k,
33
+ v,
34
+ a,
35
+ b,
36
+ gk,
37
+ o,
38
+ h0,
39
+ ht,
40
+ offsets,
41
+ scale,
42
+ T,
43
+ B: tl.constexpr,
44
+ H: tl.constexpr,
45
+ K: tl.constexpr,
46
+ V: tl.constexpr,
47
+ BK: tl.constexpr,
48
+ BV: tl.constexpr,
49
+ REVERSE: tl.constexpr,
50
+ USE_INITIAL_STATE: tl.constexpr,
51
+ STORE_FINAL_STATE: tl.constexpr,
52
+ USE_OFFSETS: tl.constexpr,
53
+ HEAD_FIRST: tl.constexpr
54
+ ):
55
+ i_v, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64)
56
+ i_n, i_h = i_nh // H, i_nh % H
57
+
58
+ if USE_OFFSETS:
59
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
60
+ T = eos - bos
61
+ else:
62
+ bos, eos = i_n * T, i_n * T + T
63
+
64
+ o_k = tl.arange(0, BK)
65
+ o_v = i_v * BV + tl.arange(0, BV)
66
+ if HEAD_FIRST:
67
+ p_q = q + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k
68
+ p_k = k + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k
69
+ p_a = a + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k
70
+ p_b = b + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k
71
+ p_gk = gk + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k
72
+ p_v = v + i_nh * T*V + ((T-1) * V if REVERSE else 0) + o_v
73
+ p_o = o + i_nh * T*V + ((T-1) * V if REVERSE else 0) + o_v
74
+
75
+ else:
76
+ p_q = q + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
77
+ p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
78
+ p_a = a + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
79
+ p_b = b + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
80
+ p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k
81
+ p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v
82
+ p_o = o + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v
83
+
84
+ mask_k = o_k < K
85
+ mask_v = o_v < V
86
+ mask_h = mask_k[None, :] & mask_v[:, None]
87
+ b_h = tl.zeros([BV, BK], dtype=tl.float32)
88
+
89
+ if USE_INITIAL_STATE:
90
+ p_h0 = h0 + i_nh * K*V + o_k[None, :] * V + o_v[:, None]
91
+ b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
92
+
93
+ for _ in range(0, T):
94
+ b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
95
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
96
+ b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
97
+ b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32)
98
+ b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
99
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
100
+
101
+ tmp = tl.sum(b_h * b_a[None, :], axis=1)
102
+ b_h = exp(b_gk)[None, :] * b_h + (tmp[:, None] * b_b[None, :] + b_k[None, :] * b_v[:, None])
103
+ b_o = tl.sum(b_h * b_q[None, :], axis=1)
104
+
105
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
106
+ p_q += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
107
+ p_k += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
108
+ p_a += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
109
+ p_b += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
110
+ p_gk += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K
111
+ p_v += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
112
+ p_o += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V
113
+
114
+ if STORE_FINAL_STATE:
115
+ p_ht = ht + i_nh * K*V + o_k[None, :] * V + o_v[:, None]
116
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
117
+
118
+
119
+ def fused_recurrent_dplr_delta_rule_fwd(
120
+ q: torch.Tensor,
121
+ k: torch.Tensor,
122
+ v: torch.Tensor,
123
+ a: torch.Tensor,
124
+ b: torch.Tensor,
125
+ gk: torch.Tensor,
126
+ scale: Optional[float] = 1.0,
127
+ initial_state: Optional[torch.Tensor] = None,
128
+ output_final_state: bool = False,
129
+ reverse: bool = False,
130
+ offsets: Optional[torch.LongTensor] = None,
131
+ head_first: bool = True
132
+ ):
133
+ if head_first:
134
+ B, H, T, K, V = *k.shape, v.shape[-1]
135
+ else:
136
+ B, T, H, K, V = *k.shape, v.shape[-1]
137
+ N = B if offsets is None else len(offsets) - 1
138
+ BK = triton.next_power_of_2(K)
139
+
140
+ h0 = initial_state
141
+ if output_final_state:
142
+ ht = q.new_empty(N, H, K, V, dtype=torch.float32)
143
+ else:
144
+ ht = None
145
+ o = torch.empty_like(v)
146
+
147
+ def grid(meta): return (triton.cdiv(V, meta['BV']), N * H)
148
+ fused_recurrent_dplr_delta_rule_fwd_kernel[grid](
149
+ q,
150
+ k,
151
+ v,
152
+ a,
153
+ b,
154
+ gk,
155
+ o,
156
+ h0,
157
+ ht,
158
+ offsets,
159
+ scale,
160
+ T=T,
161
+ B=B,
162
+ H=H,
163
+ K=K,
164
+ V=V,
165
+ BK=BK,
166
+ REVERSE=reverse,
167
+ HEAD_FIRST=head_first
168
+ )
169
+ return o, ht
170
+
171
+
172
+ class FusedRecurrentDPLRDeltaRuleFunction(torch.autograd.Function):
173
+
174
+ @staticmethod
175
+ @input_guard
176
+ @autocast_custom_fwd
177
+ def forward(
178
+ ctx,
179
+ q: torch.Tensor,
180
+ k: torch.Tensor,
181
+ v: torch.Tensor,
182
+ a: torch.Tensor,
183
+ b: torch.Tensor,
184
+ gk: torch.Tensor,
185
+ scale: Optional[float] = 1.0,
186
+ initial_state: Optional[torch.Tensor] = None,
187
+ output_final_state: bool = False,
188
+ reverse: bool = False,
189
+ offsets: Optional[torch.LongTensor] = None,
190
+ head_first: bool = False
191
+ ):
192
+ o, ht = fused_recurrent_dplr_delta_rule_fwd(
193
+ q=q,
194
+ k=k,
195
+ v=v,
196
+ a=a,
197
+ b=b,
198
+ gk=gk,
199
+ scale=scale,
200
+ initial_state=initial_state,
201
+ output_final_state=output_final_state,
202
+ reverse=reverse,
203
+ offsets=offsets,
204
+ head_first=head_first
205
+ )
206
+ return o, ht
207
+
208
+ @staticmethod
209
+ @input_guard
210
+ @autocast_custom_bwd
211
+ def backward(ctx, do, dht):
212
+ raise NotImplementedError(
213
+ "Backward pass for fused_recurrent_dplr_delta_rule is not implemented and will not be supported. "
214
+ "This kernel is only for inference. "
215
+ "For training, please use `chunk_dplr_delta_rule`."
216
+ )
217
+
218
+
219
+ def fused_recurrent_dplr_delta_rule(
220
+ q: torch.Tensor,
221
+ k: torch.Tensor,
222
+ v: torch.Tensor,
223
+ a: torch.Tensor,
224
+ b: torch.Tensor,
225
+ gk: torch.Tensor,
226
+ scale: Optional[float] = 1.0,
227
+ initial_state: Optional[torch.Tensor] = None,
228
+ output_final_state: bool = False,
229
+ reverse: bool = False,
230
+ cu_seqlens: Optional[torch.Tensor] = None,
231
+ head_first: bool = False
232
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
233
+ r"""
234
+ This function computes the recurrence S_t = S_t @ (I + a_t b_t^T) + v_t k_t^T in a recurrent manner.
235
+
236
+ Args:
237
+ q (torch.Tensor):
238
+ queries of shape `[B, H, T, K]`
239
+ k (torch.Tensor):
240
+ keys of shape `[B, H, T, K]`
241
+ v (torch.Tensor):
242
+ values of shape `[B, H, T, V]`
243
+ a (torch.Tensor):
244
+ as of shape `[B, H, T, K]`
245
+ b (torch.Tensor):
246
+ bs of shape `[B, H, T, K]`
247
+ gk (torch.Tensor):
248
+ gk of shape `[B, H, T, K]`
249
+ scale (Optional[int]):
250
+ Scale factor for the RetNet attention scores.
251
+ If None, it will default to `1 / sqrt(K)`. Default: `1.0`.
252
+ initial_state (Optional[torch.Tensor]):
253
+ Initial state of shape `[B, H, K, V]`. Default: `None`.
254
+ output_final_state (Optional[bool]):
255
+ Whether to output the final state of shape `[B, H, K, V]`. Default: `False`.
256
+ reverse (Optional[bool]):
257
+ If `True`, process the state passing in reverse order. Default: `False`.
258
+ cu_seqlens (Optional[torch.Tensor]):
259
+ Cumulative sequence lengths of shape `[N + 1]` used for variable-length training,
260
+ consistent with the FlashAttention API.
261
+ head_first (Optional[bool]):
262
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
263
+ Default: `False`.
264
+ """
265
+ if cu_seqlens is not None:
266
+ if q.shape[0] != 1:
267
+ raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
268
+ f"Please flatten variable-length inputs before processing.")
269
+ if head_first:
270
+ raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
271
+ if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
272
+ raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, "
273
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.")
274
+ if scale is None:
275
+ scale = q.shape[-1] ** -0.5
276
+ else:
277
+ assert scale > 0, "scale must be positive"
278
+ o, final_state = FusedRecurrentDPLRDeltaRuleFunction.apply(
279
+ q,
280
+ k,
281
+ v,
282
+ a,
283
+ b,
284
+ gk,
285
+ scale,
286
+ initial_state,
287
+ output_final_state,
288
+ reverse,
289
+ cu_seqlens,
290
+ head_first
291
+ )
292
+ return o, final_state
fla/ops/generalized_delta_rule/dplr/naive.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ from einops import rearrange
5
+
6
+ # S_t = S_t @ (I + alpha_t beta_t^T) + v_t k_t^T
7
+ # q, k, alpha, beta [B, H, L, D_K]
8
+ # v [B, H, L, D_V]
9
+
10
+
11
+ def dplr_recurrence(q, k, v, alpha, beta, gk, initial_state=None, output_final_state=True):
12
+ orig_dtype = q.dtype
13
+ b, h, l, d_k = q.shape
14
+ q, k, v, beta, gk = map(lambda x: x.float(), [q, k, v, beta, gk])
15
+ d_v = v.shape[-1]
16
+ o = torch.zeros_like(v)
17
+ S = torch.zeros(b, h, d_k, d_v).to(v)
18
+ q = q * (d_k ** -0.5)
19
+
20
+ if initial_state is not None:
21
+ S += initial_state
22
+
23
+ for i in range(l):
24
+ _k = k[:, :, i]
25
+ _q = q[:, :, i]
26
+ _v = v[:, :, i]
27
+ _alpha = alpha[:, :, i].clone()
28
+ _beta = beta[:, :, i].clone()
29
+ _kv = _k[..., None] * _v[..., None, :] + (S.clone() * _alpha[..., None]).sum(-2, keepdim=True) * _beta[..., None]
30
+ S = S.clone() * gk[:, :, i].exp()[..., None] + _kv
31
+ o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S)
32
+ S = None if output_final_state is False else S
33
+ return o.to(orig_dtype), S
34
+
35
+
36
+ def dplr_chunkwise(q, k, v, alpha, beta, gk, initial_state=None, output_final_state=True, chunk_size=32):
37
+ b, h, l, d_k = q.shape
38
+ d_v = v.shape[-1]
39
+ q = q * (d_k ** -0.5)
40
+ v = v
41
+ assert l % chunk_size == 0
42
+
43
+ S = k.new_zeros(b, h, d_k, d_v).to(q)
44
+ if initial_state is not None:
45
+ S += initial_state
46
+
47
+ # note that diagonal is masked.
48
+ mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0)
49
+ q, k, v, alpha, beta, gk = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d',
50
+ c=chunk_size).float(), [q, k, v, alpha, beta, gk])
51
+
52
+ gk_cumsum = gk.cumsum(-2)
53
+
54
+ # v2 = (alpha @ k.transpose(-1, -2)).masked_fill_(mask, 0) @ v
55
+ A_ab = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
56
+ A_qk = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
57
+ A_ak = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
58
+ A_qb = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device)
59
+
60
+ for i in range(chunk_size):
61
+ alpha_i = alpha[:, :, :, i, None]
62
+ q_i = q[:, :, :, i, None]
63
+ gk_i = gk_cumsum[:, :, :, i, None]
64
+ mask = (torch.arange(chunk_size) <= i).to(q.device)
65
+ attn_i = (gk_i - gk_cumsum).masked_fill(~mask.unsqueeze(-1), float('-inf')).exp()
66
+ A_qk[:, :, :, i, :] = (q_i * k * attn_i).sum(-1).clone()
67
+ A_qb[:, :, :, i, :] = (q_i * beta * attn_i).sum(-1).clone()
68
+ mask = (torch.arange(chunk_size) < i).to(q.device)
69
+ # shift by one.
70
+ attn_i = (gk_i - gk[:, :, :, i, None] - gk_cumsum).masked_fill(~mask.unsqueeze(-1), float('-inf')).exp()
71
+ A_ab[:, :, :, i, :] = (alpha_i * beta * attn_i).sum(-1).clone()
72
+ A_ak[:, :, :, i, :] = (alpha_i * k * attn_i).sum(-1).clone()
73
+
74
+ A_ab = A_ab
75
+ for i in range(1, chunk_size):
76
+ A_ab[..., i, :i] = A_ab[..., i, :i].clone() + (A_ab[..., i, :, None].clone() * A_ab[..., :, :i].clone()).sum(-2)
77
+
78
+ A_ab = A_ab + torch.eye(chunk_size, dtype=torch.float, device=q.device)
79
+ u = A_ab @ (A_ak @ v)
80
+ w = A_ab @ ((gk_cumsum-gk).exp() * alpha)
81
+
82
+ o = torch.zeros_like(v)
83
+ mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1)
84
+ for i in range(0, l // chunk_size):
85
+ q_i, k_i, v_i, u_i, w_i, beta_i = q[:, :, i], k[:, :, i], v[:, :, i], u[:, :, i], w[:, :, i], beta[:, :, i]
86
+ v2_i = u_i + w_i @ S
87
+
88
+ o_1 = A_qk[:, :, i] @ v_i
89
+ o_2 = A_qb[:, :, i] @ v2_i
90
+ o_3 = (q_i * gk_cumsum[:, :, i].exp()) @ S
91
+ o[:, :, i] = o_1 + o_2 + o_3
92
+ decay = (gk_cumsum[:, :, i, -1, None] - gk_cumsum[:, :, i]).exp()
93
+ S = S*gk_cumsum[:, :, i, -1, :, None].exp() + (k_i * decay).transpose(-1, -2) @ v_i + \
94
+ (beta_i * decay).transpose(-1, -2) @ v2_i
95
+ S = None if output_final_state is False else S
96
+ return rearrange(o, 'b h n c d -> b h (n c) d'), S
fla/ops/generalized_delta_rule/iplr/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (328 Bytes). View file
 
fla/ops/generalized_delta_rule/iplr/__pycache__/wy_fast.cpython-312.pyc ADDED
Binary file (23.1 kB). View file
 
fla/ops/generalized_delta_rule/iplr/wy_fast.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
4
+
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ import triton
9
+ import triton.language as tl
10
+
11
+ from fla.utils import check_shared_mem, is_nvidia_hopper
12
+
13
+ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
14
+
15
+
16
+ @triton.heuristics({
17
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
18
+ })
19
+ @triton.autotune(
20
+ configs=[
21
+ triton.Config({}, num_warps=num_warps)
22
+ for num_warps in [1, 2, 4, 8, 16]
23
+ ],
24
+ key=['BK']
25
+ )
26
+ @triton.jit(do_not_specialize=['T'])
27
+ def fwd_prepare_wy_repr_kernel_chunk32(
28
+ a,
29
+ b,
30
+ A,
31
+ offsets,
32
+ indices,
33
+ T,
34
+ H: tl.constexpr,
35
+ K: tl.constexpr,
36
+ BT: tl.constexpr,
37
+ BK: tl.constexpr,
38
+ BC: tl.constexpr, # dummy placeholder
39
+ USE_OFFSETS: tl.constexpr,
40
+ HEAD_FIRST: tl.constexpr,
41
+ ):
42
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
43
+ i_b, i_h = i_bh // H, i_bh % H
44
+ if USE_OFFSETS:
45
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
46
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
47
+ T = eos - bos
48
+ else:
49
+ bos, eos = i_b * T, i_b * T + T
50
+
51
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
52
+ for i_k in range(tl.cdiv(K, BK)):
53
+ if HEAD_FIRST:
54
+ p_a = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
55
+ p_b = tl.make_block_ptr(b + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
56
+ else:
57
+ p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
58
+ p_b = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
59
+ b_a = tl.load(p_a, boundary_check=(0, 1))
60
+ b_b = tl.load(p_b, boundary_check=(0, 1))
61
+ b_A += tl.dot(b_a, b_b)
62
+
63
+ b_A = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0)
64
+ for i in range(1, BT):
65
+ mask = tl.arange(0, BT) == i
66
+ b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
67
+ b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i)
68
+ b_A = tl.where(mask[:, None], b_a, b_A)
69
+ b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]
70
+
71
+ if HEAD_FIRST:
72
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
73
+ else:
74
+ p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
75
+ tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
76
+
77
+
78
+ @triton.heuristics({
79
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
80
+ })
81
+ @triton.autotune(
82
+ configs=[
83
+ triton.Config({}, num_warps=num_warps)
84
+ for num_warps in [1, 2, 4, 8, 16]
85
+ ],
86
+ key=['BK']
87
+ )
88
+ @triton.jit(do_not_specialize=['T'])
89
+ def fwd_prepare_wy_repr_kernel_chunk64(
90
+ a,
91
+ b,
92
+ A,
93
+ offsets,
94
+ indices,
95
+ T,
96
+ H: tl.constexpr,
97
+ K: tl.constexpr,
98
+ BT: tl.constexpr,
99
+ BK: tl.constexpr,
100
+ BC: tl.constexpr,
101
+ USE_OFFSETS: tl.constexpr,
102
+ HEAD_FIRST: tl.constexpr
103
+ ):
104
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
105
+ i_b, i_h = i_bh // H, i_bh % H
106
+ if USE_OFFSETS:
107
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
108
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
109
+ T = eos - bos
110
+ else:
111
+ bos, eos = i_b * T, i_b * T + T
112
+
113
+ b_A = tl.zeros([BC, BC], dtype=tl.float32)
114
+ b_A2 = tl.zeros([BC, BC], dtype=tl.float32)
115
+ b_A3 = tl.zeros([BC, BC], dtype=tl.float32)
116
+
117
+ for i_k in range(tl.cdiv(K, BK)):
118
+ if HEAD_FIRST:
119
+ p_a1 = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
120
+ p_a2 = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0))
121
+ p_b1 = tl.make_block_ptr(b + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BC), (0, 1))
122
+ p_b2 = tl.make_block_ptr(b + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + BC), (BK, BC), (0, 1))
123
+ else:
124
+ p_a1 = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0))
125
+ p_a2 = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0))
126
+ p_b1 = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT), (BK, BC), (0, 1))
127
+ p_b2 = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT + BC), (BK, BC), (0, 1))
128
+ b_a1 = tl.load(p_a1, boundary_check=(0, 1))
129
+ b_a2 = tl.load(p_a2, boundary_check=(0, 1))
130
+ b_b1 = tl.load(p_b1, boundary_check=(0, 1))
131
+ b_b2 = tl.load(p_b2, boundary_check=(0, 1))
132
+ b_A += tl.dot(b_a1, b_b1, allow_tf32=False)
133
+ b_A2 += tl.dot(b_a2, b_b2, allow_tf32=False)
134
+ b_A3 += tl.dot(b_a2, b_b1, allow_tf32=False)
135
+
136
+ b_A = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A, 0)
137
+ b_A2 = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A2, 0)
138
+
139
+ for i in range(1, BC):
140
+ mask = tl.arange(0, BC) == i
141
+ b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
142
+ b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
143
+ b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BC) < i)
144
+ b_a2 = b_a2 + tl.sum(b_a2[:, None] * b_A2, 0) * (tl.arange(0, BC) < i)
145
+ b_A = tl.where(mask[:, None], b_a, b_A)
146
+ b_A2 = tl.where(mask[:, None], b_a2, b_A2)
147
+
148
+ # blockwise computation of lower triangular matrix's inverse
149
+ # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1]
150
+ b_A += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
151
+ b_A2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
152
+ b_A3 = tl.dot(tl.dot(b_A2, b_A3, allow_tf32=False), b_A, allow_tf32=False)
153
+
154
+ if HEAD_FIRST:
155
+ p_A1 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
156
+ p_A2 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
157
+ p_A3 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
158
+ p_A4 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
159
+ else:
160
+ p_A1 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0))
161
+ p_A2 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0))
162
+ p_A3 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0))
163
+ p_A4 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0))
164
+ tl.store(p_A1, b_A.to(p_A1.dtype.element_ty), boundary_check=(0, 1))
165
+ tl.store(p_A2, b_A2.to(p_A2.dtype.element_ty), boundary_check=(0, 1))
166
+ tl.store(p_A3, b_A3.to(p_A3.dtype.element_ty), boundary_check=(0, 1))
167
+ # causal mask
168
+ tl.store(p_A4, tl.zeros([BC, BC], dtype=tl.float32).to(p_A4.dtype.element_ty), boundary_check=(0, 1))
169
+
170
+
171
+ @triton.heuristics({
172
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
173
+ })
174
+ @triton.autotune(
175
+ configs=[
176
+ triton.Config({}, num_warps=num_warps)
177
+ for num_warps in NUM_WARPS
178
+ ],
179
+ key=['BT', 'BK', 'BV']
180
+ )
181
+ @triton.jit(do_not_specialize=['T'])
182
+ def fwd_wu_kernel(
183
+ w,
184
+ u,
185
+ a,
186
+ k,
187
+ v,
188
+ A,
189
+ offsets,
190
+ indices,
191
+ T,
192
+ H: tl.constexpr,
193
+ K: tl.constexpr,
194
+ V: tl.constexpr,
195
+ BT: tl.constexpr,
196
+ BK: tl.constexpr,
197
+ BV: tl.constexpr,
198
+ USE_OFFSETS: tl.constexpr,
199
+ HEAD_FIRST: tl.constexpr
200
+ ):
201
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
202
+ i_b, i_h = i_bh // H, i_bh % H
203
+ if USE_OFFSETS:
204
+ i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
205
+ bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
206
+ T = eos - bos
207
+ else:
208
+ bos, eos = i_b * T, i_b * T + T
209
+
210
+ if HEAD_FIRST:
211
+ p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
212
+ else:
213
+ p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
214
+
215
+ b_A = tl.load(p_A, boundary_check=(0, 1))
216
+ b_Aak = tl.zeros([BT, BT], dtype=tl.float32)
217
+
218
+ for i_k in range(tl.cdiv(K, BK)):
219
+ if HEAD_FIRST:
220
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
221
+ p_a = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
222
+ p_w = tl.make_block_ptr(w + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
223
+ else:
224
+ p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
225
+ p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
226
+ p_w = tl.make_block_ptr(w + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
227
+ b_k = tl.load(p_k, boundary_check=(0, 1))
228
+ b_a = tl.load(p_a, boundary_check=(0, 1))
229
+ b_w = tl.dot(b_A, b_a)
230
+ b_Aak += tl.dot(b_a, tl.trans(b_k))
231
+ tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
232
+
233
+ b_Aak = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_Aak, 0)
234
+ b_Aak = b_Aak.to(k.dtype.element_ty)
235
+
236
+ for i_v in range(tl.cdiv(V, BV)):
237
+ if HEAD_FIRST:
238
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
239
+ p_u = tl.make_block_ptr(u + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
240
+ else:
241
+ p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
242
+ p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
243
+ b_v = tl.load(p_v, boundary_check=(0, 1))
244
+ b_v = tl.dot(b_Aak, b_v).to(v.dtype.element_ty)
245
+ b_u = tl.dot(b_A, b_v)
246
+ tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
247
+
248
+
249
+ def fwd_prepare_wy_repr(
250
+ a: torch.Tensor,
251
+ b: torch.Tensor,
252
+ v: torch.Tensor,
253
+ k: torch.Tensor,
254
+ offsets: Optional[torch.LongTensor],
255
+ indices: Optional[torch.LongTensor],
256
+ head_first: bool = True,
257
+ chunk_size: int = 64
258
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
259
+ if head_first:
260
+ B, H, T, K = a.shape
261
+ else:
262
+ B, T, H, K = a.shape
263
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
264
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
265
+ BC = min(BT, 32)
266
+ BK = min(triton.next_power_of_2(K), 64)
267
+
268
+ A = torch.empty(B, *((H, T) if head_first else (T, H)), BT, device=a.device, dtype=a.dtype)
269
+ fwd_fn = fwd_prepare_wy_repr_kernel_chunk64 if BT == 64 else fwd_prepare_wy_repr_kernel_chunk32
270
+
271
+ fwd_fn[(NT, B * H)](
272
+ a=a,
273
+ b=b,
274
+ A=A,
275
+ offsets=offsets,
276
+ indices=indices,
277
+ T=T,
278
+ H=H,
279
+ K=K,
280
+ BT=BT,
281
+ BK=BK,
282
+ BC=BC,
283
+ HEAD_FIRST=head_first
284
+ )
285
+ w, u = fwd_wu(
286
+ a=a,
287
+ v=v,
288
+ k=k,
289
+ A=A,
290
+ offsets=offsets,
291
+ indices=indices,
292
+ head_first=head_first,
293
+ chunk_size=chunk_size
294
+ )
295
+ return w, u, A
296
+
297
+
298
+ def fwd_wu(
299
+ a: torch.Tensor,
300
+ v: torch.Tensor,
301
+ k: torch.Tensor,
302
+ A: torch.Tensor,
303
+ offsets: Optional[torch.LongTensor],
304
+ indices: Optional[torch.LongTensor],
305
+ head_first: bool,
306
+ chunk_size: int
307
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
308
+ if head_first:
309
+ B, H, T, K, V = *a.shape, v.shape[-1]
310
+ else:
311
+ B, T, H, K, V = *a.shape, v.shape[-1]
312
+ BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
313
+ NT = triton.cdiv(T, BT) if offsets is None else len(indices)
314
+ CONST_TILING = 64 if check_shared_mem() else 32
315
+ BK = min(triton.next_power_of_2(K), CONST_TILING)
316
+ BV = min(triton.next_power_of_2(V), CONST_TILING)
317
+
318
+ u = torch.empty_like(v)
319
+ w = torch.empty_like(a)
320
+ fwd_wu_kernel[(NT, B*H)](
321
+ a=a,
322
+ v=v,
323
+ w=w,
324
+ u=u,
325
+ A=A,
326
+ k=k,
327
+ offsets=offsets,
328
+ indices=indices,
329
+ T=T,
330
+ H=H,
331
+ K=K,
332
+ V=V,
333
+ BT=BT,
334
+ BK=BK,
335
+ BV=BV,
336
+ HEAD_FIRST=head_first
337
+ )
338
+ return w, u
fla/ops/gla/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (333 Bytes). View file
 
fla/ops/gla/__pycache__/chunk.cpython-312.pyc ADDED
Binary file (81.8 kB). View file
 
fla/ops/gla/__pycache__/fused_chunk.cpython-312.pyc ADDED
Binary file (35.3 kB). View file
 
fla/ops/gla/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (5.69 kB). View file
 
fla/ops/gsa/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .chunk import chunk_gsa
4
+ from .fused_recurrent import fused_recurrent_gsa
5
+
6
+ __all__ = [
7
+ 'chunk_gsa',
8
+ 'fused_recurrent_gsa'
9
+ ]
fla/ops/gsa/naive.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from einops import repeat
7
+
8
+
9
+ def naive_recurrent_gsa(
10
+ q: torch.Tensor,
11
+ k: torch.Tensor,
12
+ v: torch.Tensor,
13
+ s: torch.Tensor,
14
+ g: Optional[torch.Tensor] = None,
15
+ scale: Optional[int] = None,
16
+ initial_state: Optional[torch.Tensor] = None,
17
+ output_final_state: Optional[bool] = False
18
+ ) -> torch.Tensor:
19
+ dtype = q.dtype
20
+
21
+ NG = q.shape[1]//k.shape[1]
22
+ # [batch_size, n_heads, seq_len, n_slots]
23
+ if g is None:
24
+ z = s.float().logcumsumexp(2)
25
+ g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z
26
+ s = torch.exp(s - z)
27
+ q, k, v, s, g = map(lambda x: x.float(), (q, k, v, s, g))
28
+ k, v, s, g = map(lambda x: repeat(x, 'b h t d -> b (h g) t d', g=NG), (k, v, s, g))
29
+ if initial_state is not None:
30
+ initial_state = tuple(map(lambda x: repeat(x, 'b h k v -> b (h g) k v', g=NG), initial_state))
31
+
32
+ B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1]
33
+
34
+ hk = torch.zeros(B, H, K, M, dtype=torch.float, device=q.device)
35
+ ok = torch.zeros_like(s)
36
+
37
+ if scale is None:
38
+ scale = q.shape[-1] ** -0.5
39
+
40
+ final_state = None
41
+ if initial_state is not None:
42
+ hk += initial_state[0]
43
+
44
+ for i in range(T):
45
+ q_i = q[:, :, i] * scale
46
+ k_i = k[:, :, i]
47
+ v_i = s[:, :, i]
48
+ g_i = g[:, :, i].exp()
49
+ hk = hk * g_i[..., None, :] + k_i[..., None] * v_i[..., None, :]
50
+ ok[:, :, i] = (q_i[..., None] * hk).sum(-2)
51
+
52
+ qv = ok.softmax(-1)
53
+ hv = torch.zeros(B, H, M, V, dtype=torch.float, device=q.device)
54
+ ov = torch.zeros_like(v)
55
+ if initial_state is not None:
56
+ hv += initial_state[1]
57
+
58
+ for i in range(T):
59
+ q_i = qv[:, :, i]
60
+ k_i = s[:, :, i]
61
+ v_i = v[:, :, i]
62
+ g_i = g[:, :, i].exp()
63
+ hv = hv * g_i[..., :, None] + k_i[..., None] * v_i[..., None, :]
64
+ ov[:, :, i] = (q_i[..., None] * hv).sum(-2)
65
+
66
+ if output_final_state:
67
+ final_state = (hk.view(B, -1, NG, K, M)[:, :, 0], hv.view(B, -1, NG, M, V)[:, :, 0])
68
+ return ov.to(dtype), final_state
fla/ops/hgrn/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (14.3 kB). View file
 
fla/ops/hgrn/chunk.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ # this function implements the chunkwise form of HGRN, inspired by
5
+ # [Volodymyr Kyrylov in his blog post](https://proger.github.io/posts/scan/chunk.html)
6
+ # also refer to the `accelerated-scan` lib: https://github.com/proger/accelerated-scan
7
+
8
+ # from tests on H800, with B, D = 16, 128, we see that the chunk can be greatly faster than the recurrent:
9
+ #
10
+ # Performance:
11
+ # seq_len chunk recurrent chunk_bwd recurrent_bwd
12
+ # 0 128.0 0.039360 0.061056 0.312160 0.205008
13
+ # 1 256.0 0.045824 0.123712 0.308784 0.297696
14
+ # 2 512.0 0.058688 0.241952 0.310720 0.626528
15
+ # 3 1024.0 0.088288 0.476992 0.313184 1.333152
16
+ # 4 2048.0 0.169472 0.943264 0.452464 2.724864
17
+ # 5 4096.0 0.329920 1.886144 0.881600 5.551520
18
+ # 6 8192.0 0.647872 3.755040 1.740496 11.117184
19
+ # 7 16384.0 1.272064 7.520576 3.446608 22.362528
20
+
21
+ from typing import Tuple
22
+
23
+ import torch
24
+ import triton
25
+ import triton.language as tl
26
+
27
+ from fla.ops.utils.op import exp
28
+ from fla.utils import input_guard
29
+
30
+
31
+ @triton.autotune(
32
+ configs=[
33
+ triton.Config({'BD': 32}, num_warps=1),
34
+ triton.Config({'BD': 32}, num_warps=2),
35
+ triton.Config({'BD': 32}, num_warps=4),
36
+ triton.Config({'BD': 32}, num_warps=8),
37
+ triton.Config({'BD': 64}, num_warps=1),
38
+ triton.Config({'BD': 64}, num_warps=2),
39
+ triton.Config({'BD': 64}, num_warps=4),
40
+ triton.Config({'BD': 64}, num_warps=8),
41
+ triton.Config({'BD': 128}, num_warps=1),
42
+ triton.Config({'BD': 128}, num_warps=2),
43
+ triton.Config({'BD': 128}, num_warps=4),
44
+ triton.Config({'BD': 128}, num_warps=8),
45
+ ],
46
+ key=['D']
47
+ )
48
+ @triton.jit(do_not_specialize=['T'])
49
+ def chunk_hgrn_fwd_kernel_h(
50
+ x,
51
+ g,
52
+ gc,
53
+ o,
54
+ h0,
55
+ T,
56
+ D: tl.constexpr,
57
+ BT: tl.constexpr,
58
+ BD: tl.constexpr,
59
+ USE_INITIAL_STATE: tl.constexpr
60
+ ):
61
+ i_d, i_t, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
62
+ o_d = i_d * BD + tl.arange(0, BD)
63
+ mask = o_d < D
64
+
65
+ p_x = x + i_b * T * D + i_t * BT * D + o_d
66
+ p_g = g + i_b * T * D + i_t * BT * D + o_d
67
+ p_gc = gc + i_b * T * D + i_t * BT * D + o_d
68
+ p_o = o + i_b * T * D + i_t * BT * D + o_d
69
+
70
+ b_h = tl.zeros([BD], dtype=tl.float32)
71
+ b_gc = tl.zeros([BD], dtype=tl.float32)
72
+ if USE_INITIAL_STATE:
73
+ if i_t == 0:
74
+ b_h += tl.load(h0 + i_b * D + o_d, mask=mask, other=0).to(tl.float32)
75
+ for i in range(0, BT):
76
+ mask_t = mask & ((i_t * BT + i) < T)
77
+ b_x = tl.load(p_x, mask=mask_t, other=0).to(tl.float32)
78
+ b_g = tl.load(p_g, mask=mask_t, other=0).to(tl.float32)
79
+ b_h = exp(b_g) * b_h + b_x
80
+ b_gc = b_gc + b_g
81
+ tl.store(p_gc, b_gc.to(p_o.dtype.element_ty), mask=mask_t)
82
+ tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask_t)
83
+
84
+ p_x += D
85
+ p_g += D
86
+ p_gc += D
87
+ p_o += D
88
+
89
+
90
+ @triton.jit(do_not_specialize=['T'])
91
+ def chunk_hgrn_fwd_kernel_o(
92
+ gc,
93
+ o,
94
+ s_b,
95
+ s_t,
96
+ s_d,
97
+ T,
98
+ D: tl.constexpr,
99
+ BT: tl.constexpr,
100
+ BD: tl.constexpr
101
+ ):
102
+ i_d, i_b = tl.program_id(0), tl.program_id(1)
103
+ o_d = i_d * BD + tl.arange(0, BD)
104
+ mask = o_d < D
105
+
106
+ for i_t in range(1, tl.cdiv(T, BT)):
107
+ p_gc = tl.make_block_ptr(gc + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
108
+ p_o = tl.make_block_ptr(o + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
109
+
110
+ # [BD,]
111
+ b_h0 = tl.load(o + i_b * T * D + i_t * BT * D - D + o_d, mask=mask, other=0).to(tl.float32)
112
+ # [BT, BD]
113
+ b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32)
114
+ b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)
115
+ b_o = b_o + exp(b_gc) * b_h0[None, :]
116
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
117
+
118
+
119
+ @triton.autotune(
120
+ configs=[
121
+ triton.Config({'BD': BD}, num_warps=num_warps)
122
+ for BD in [32, 64, 128]
123
+ for num_warps in [1, 2, 4, 8]
124
+ ],
125
+ key=['D']
126
+ )
127
+ @triton.jit(do_not_specialize=['T'])
128
+ def chunk_hgrn_bwd_kernel_h(
129
+ g,
130
+ gc,
131
+ dx,
132
+ do,
133
+ T,
134
+ D: tl.constexpr,
135
+ BT: tl.constexpr,
136
+ BD: tl.constexpr
137
+ ):
138
+ i_d, i_t, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
139
+ o_d = i_d * BD + tl.arange(0, BD)
140
+ mask = o_d < D
141
+ BC = min(BT, T - i_t * BT)
142
+ NT = tl.num_programs(1)
143
+
144
+ p_g = g + (i_b * T + i_t * BT + BC - 1) * D + o_d
145
+ p_gc = gc + (i_b * T + i_t * BT + BC - 1) * D + o_d
146
+ p_dx = dx + (i_b * T + i_t * BT + BC - 1) * D + o_d
147
+ p_do = do + (i_b * T + i_t * BT + BC - 1) * D + o_d
148
+
149
+ if i_t == NT - 1:
150
+ b_gc = tl.zeros([BD], dtype=tl.float32)
151
+ else:
152
+ b_gc = tl.load(g + (i_b * T + i_t * BT + BT) * D + o_d, mask=mask, other=0).to(tl.float32)
153
+ b_dh = tl.zeros([BD], dtype=tl.float32)
154
+ for _ in range(BC - 1, -1, -1):
155
+ tl.store(p_gc, b_gc.to(p_gc.dtype.element_ty), mask=mask)
156
+
157
+ b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
158
+ b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32)
159
+
160
+ b_gc = b_gc + b_g
161
+ b_dh = b_dh + b_do
162
+ b_dx = b_dh
163
+ b_dh = b_dh * exp(b_g)
164
+
165
+ tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)
166
+
167
+ p_g -= D
168
+ p_gc -= D
169
+ p_dx -= D
170
+ p_do -= D
171
+
172
+
173
+ @triton.jit(do_not_specialize=['T'])
174
+ def chunk_hgrn_bwd_kernel_o(
175
+ g,
176
+ gc,
177
+ o,
178
+ dx,
179
+ dg,
180
+ s_b,
181
+ s_t,
182
+ s_d,
183
+ T,
184
+ D: tl.constexpr,
185
+ BT: tl.constexpr,
186
+ BD: tl.constexpr
187
+ ):
188
+ i_d, i_b = tl.program_id(0), tl.program_id(1)
189
+ o_d = i_d * BD + tl.arange(0, BD)
190
+ mask = o_d < D
191
+
192
+ for i_t in range(tl.cdiv(T, BT) - 1, -1, -1):
193
+ p_g = tl.make_block_ptr(g + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
194
+ p_gc = tl.make_block_ptr(gc + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
195
+ p_o = tl.make_block_ptr(o + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT - 1, i_d * BD), (BT, BD), (1, 0))
196
+ p_dx = tl.make_block_ptr(dx + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
197
+ p_dg = tl.make_block_ptr(dg + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
198
+
199
+ # [BD,]
200
+ mask_t = mask & ((i_t + 1) * BT < T)
201
+ b_ht = tl.load(dx + i_b * T * D + (i_t + 1) * BT * D + o_d, mask=mask_t, other=0).to(tl.float32)
202
+ # [BT, BD]
203
+ b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
204
+ b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32)
205
+ b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)
206
+ b_dx = tl.load(p_dx, boundary_check=(0, 1)).to(tl.float32)
207
+
208
+ b_dx = b_dx + exp(b_gc) * b_ht[None, :]
209
+ b_dg = b_o * b_dx * exp(b_g)
210
+ tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1))
211
+ tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))
212
+
213
+
214
+ class ChunkHGRNFunction(torch.autograd.Function):
215
+
216
+ @staticmethod
217
+ @input_guard
218
+ def forward(ctx, x, g, initial_state=None, output_final_state=False):
219
+ B, T, D = x.shape
220
+ BT, BD = 128, min(64, triton.next_power_of_2(D))
221
+ num_warps = 8 if BD == 64 else 4
222
+
223
+ gc = torch.empty_like(g, dtype=torch.float)
224
+ o = torch.empty_like(x, dtype=torch.float)
225
+ def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B)
226
+ chunk_hgrn_fwd_kernel_h[grid](
227
+ x, g, gc, o, initial_state,
228
+ T=T, D=D, BT=BT,
229
+ USE_INITIAL_STATE=initial_state is not None
230
+ )
231
+ def grid(meta): return (triton.cdiv(D, meta['BD']), B)
232
+ chunk_hgrn_fwd_kernel_o[grid](
233
+ gc, o,
234
+ o.stride(-3), o.stride(-2), o.stride(-1),
235
+ T=T, D=D, BT=BT, BD=BD,
236
+ num_warps=num_warps
237
+ )
238
+ final_state = None
239
+ if output_final_state:
240
+ final_state = o[:, -1].clone()
241
+ o = o.to(x.dtype)
242
+ ctx.save_for_backward(g, o, initial_state)
243
+ return o, final_state
244
+
245
+ @staticmethod
246
+ @input_guard
247
+ def backward(ctx, do, dht=None):
248
+ g, o, initial_state = ctx.saved_tensors
249
+ B, T, D = do.shape
250
+ BT, BD = 128, min(64, triton.next_power_of_2(D))
251
+ num_warps = 8 if BD == 64 else 4
252
+
253
+ gc = torch.empty_like(g, dtype=torch.float)
254
+ dx = torch.empty_like(o, dtype=torch.float)
255
+ def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B)
256
+ chunk_hgrn_bwd_kernel_h[grid](
257
+ g, gc, dx, do,
258
+ T=T, D=D, BT=BT
259
+ )
260
+
261
+ dg = torch.empty_like(g, dtype=torch.float)
262
+ def grid(meta): return (triton.cdiv(D, meta['BD']), B)
263
+ chunk_hgrn_bwd_kernel_o[grid](
264
+ g, gc, o, dx, dg,
265
+ o.stride(-3), o.stride(-2), o.stride(-1),
266
+ T=T, D=D, BT=BT, BD=BD,
267
+ num_warps=num_warps
268
+ )
269
+ if initial_state is not None:
270
+ dg[:, 0] = (initial_state * dx[:, 0] * g[:, 0].float().exp()).to(dg.dtype)
271
+
272
+ return dx.to(o.dtype), dg, None, None
273
+
274
+
275
+ @torch.compiler.disable
276
+ def chunk_hgrn(
277
+ x: torch.Tensor,
278
+ g: torch.Tensor,
279
+ initial_state: torch.Tensor = None,
280
+ output_final_state: bool = False
281
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
282
+ return ChunkHGRNFunction.apply(x, g, initial_state, output_final_state)
fla/ops/hgrn/fused_recurrent.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.utils.op import exp
11
+ from fla.utils import input_guard
12
+
13
+
14
+ @triton.heuristics({
15
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
16
+ 'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
17
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
18
+ })
19
+ @triton.autotune(
20
+ configs=[
21
+ triton.Config({'BD': BD}, num_warps=num_warps)
22
+ for BD in [32, 64, 128]
23
+ for num_warps in [1, 2, 4, 8]
24
+ ],
25
+ key=['D']
26
+ )
27
+ @triton.jit(do_not_specialize=['T'])
28
+ def fused_recurrent_hgrn_fwd_kernel(
29
+ x,
30
+ g,
31
+ o,
32
+ h0,
33
+ ht,
34
+ offsets,
35
+ T,
36
+ D: tl.constexpr,
37
+ BD: tl.constexpr,
38
+ USE_INITIAL_STATE: tl.constexpr,
39
+ STORE_FINAL_STATE: tl.constexpr,
40
+ USE_OFFSETS: tl.constexpr
41
+ ):
42
+ i_d, i_n = tl.program_id(0), tl.program_id(1)
43
+ if USE_OFFSETS:
44
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
45
+ T = eos - bos
46
+ else:
47
+ bos, eos = i_n * T, i_n * T + T
48
+
49
+ o_d = i_d * BD + tl.arange(0, BD)
50
+ mask = o_d < D
51
+
52
+ p_x = x + bos * D + o_d
53
+ p_g = g + bos * D + o_d
54
+ p_o = o + bos * D + o_d
55
+
56
+ b_h = tl.zeros([BD], dtype=tl.float32)
57
+ if USE_INITIAL_STATE:
58
+ p_h0 = h0 + i_n * D + o_d
59
+ b_h += tl.load(p_h0, mask=mask, other=0).to(tl.float32)
60
+ for _ in range(0, T):
61
+ b_x = tl.load(p_x, mask=mask, other=0).to(tl.float32)
62
+ b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
63
+ b_h = exp(b_g) * b_h + b_x
64
+ tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask)
65
+
66
+ p_x += D
67
+ p_g += D
68
+ p_o += D
69
+
70
+ if STORE_FINAL_STATE:
71
+ p_ht = ht + i_n * D + o_d
72
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask)
73
+
74
+
75
+ @triton.heuristics({
76
+ 'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
77
+ 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
78
+ 'USE_OFFSETS': lambda args: args['offsets'] is not None
79
+ })
80
+ @triton.autotune(
81
+ configs=[
82
+ triton.Config({'BD': BD}, num_warps=num_warps)
83
+ for BD in [32, 64, 128]
84
+ for num_warps in [1, 2, 4, 8]
85
+ ],
86
+ key=['D']
87
+ )
88
+ @triton.jit(do_not_specialize=['T'])
89
+ def fused_recurrent_hgrn_bwd_kernel(
90
+ g,
91
+ o,
92
+ h0,
93
+ dx,
94
+ dg,
95
+ do,
96
+ dht,
97
+ dh0,
98
+ offsets,
99
+ T,
100
+ D: tl.constexpr,
101
+ BD: tl.constexpr,
102
+ USE_INITIAL_STATE: tl.constexpr,
103
+ USE_FINAL_STATE_GRADIENT: tl.constexpr,
104
+ USE_OFFSETS: tl.constexpr
105
+ ):
106
+ i_d, i_n = tl.program_id(0), tl.program_id(1)
107
+ if USE_OFFSETS:
108
+ bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
109
+ T = eos - bos
110
+ else:
111
+ bos, eos = i_n * T, i_n * T + T
112
+
113
+ o_d = i_d * BD + tl.arange(0, BD)
114
+ mask = o_d < D
115
+
116
+ p_g = g + (bos + T - 1) * D + o_d
117
+ p_o = o + (bos + T - 2) * D + o_d
118
+ p_dx = dx + (bos + T - 1) * D + o_d
119
+ p_dg = dg + (bos + T - 1) * D + o_d
120
+ p_do = do + (bos + T - 1) * D + o_d
121
+
122
+ b_dh = tl.zeros([BD], dtype=tl.float32)
123
+ if USE_FINAL_STATE_GRADIENT:
124
+ p_dht = dht + i_n * D + o_d
125
+ b_dh += tl.load(p_dht, mask=mask, other=0).to(tl.float32)
126
+
127
+ for i in range(T - 1, -1, -1):
128
+ b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
129
+ b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32)
130
+ if i > 0:
131
+ b_o = tl.load(p_o, mask=mask, other=0).to(tl.float32)
132
+ elif USE_INITIAL_STATE:
133
+ b_o = tl.load(h0 + i_n * D + o_d, mask=mask, other=0).to(tl.float32)
134
+ else:
135
+ b_o = tl.zeros([BD], dtype=tl.float32)
136
+
137
+ b_dh = b_dh + b_do
138
+ b_dx = b_dh
139
+ b_dh = b_dh * exp(b_g)
140
+ b_dg = b_dh * b_o
141
+ tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)
142
+ tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), mask=mask)
143
+
144
+ p_g -= D
145
+ p_o -= D
146
+ p_dx -= D
147
+ p_dg -= D
148
+ p_do -= D
149
+
150
+ if USE_INITIAL_STATE:
151
+ p_dh0 = dh0 + i_n * D + o_d
152
+ tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask)
153
+
154
+
155
+ def fused_recurrent_hgrn_fwd(
156
+ x: torch.Tensor,
157
+ g: torch.Tensor,
158
+ initial_state: torch.Tensor = None,
159
+ output_final_state: bool = False,
160
+ offsets: Optional[torch.LongTensor] = None,
161
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
162
+ B, T, D = x.shape
163
+ N = B if offsets is None else len(offsets) - 1
164
+
165
+ o = torch.empty_like(x)
166
+ final_state = x.new_empty(N, D) if output_final_state else None
167
+
168
+ def grid(meta): return (triton.cdiv(D, meta['BD']), N)
169
+ fused_recurrent_hgrn_fwd_kernel[grid](
170
+ x=x,
171
+ g=g,
172
+ o=o,
173
+ h0=initial_state,
174
+ ht=final_state,
175
+ offsets=offsets,
176
+ T=T,
177
+ D=D
178
+ )
179
+ return o, final_state
180
+
181
+
182
+ def fused_recurrent_hgrn_bwd(
183
+ g: torch.Tensor,
184
+ o: torch.Tensor,
185
+ do: torch.Tensor,
186
+ dht: torch.Tensor = None,
187
+ initial_state: torch.Tensor = None,
188
+ offsets: Optional[torch.LongTensor] = None
189
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
190
+ B, T, D = do.shape
191
+ N = B if offsets is None else len(offsets) - 1
192
+
193
+ dx = torch.empty_like(o, dtype=torch.float)
194
+ dg = torch.empty_like(g, dtype=torch.float)
195
+ dh0 = torch.empty_like(initial_state, dtype=torch.float) if initial_state is not None else None
196
+ def grid(meta): return (triton.cdiv(D, meta['BD']), N)
197
+ fused_recurrent_hgrn_bwd_kernel[grid](
198
+ g=g,
199
+ o=o,
200
+ h0=initial_state,
201
+ dx=dx,
202
+ dg=dg,
203
+ do=do,
204
+ dht=dht,
205
+ dh0=dh0,
206
+ offsets=offsets,
207
+ T=T,
208
+ D=D
209
+ )
210
+ return dx, dg, dh0
211
+
212
+
213
+ class FusedRecurrentHGRNFunction(torch.autograd.Function):
214
+
215
+ @staticmethod
216
+ @input_guard
217
+ def forward(
218
+ ctx,
219
+ x: torch.Tensor,
220
+ g: torch.Tensor,
221
+ initial_state: torch.Tensor = None,
222
+ output_final_state: bool = False,
223
+ offsets: Optional[torch.LongTensor] = None
224
+ ):
225
+ o, ht = fused_recurrent_hgrn_fwd(
226
+ x=x,
227
+ g=g,
228
+ initial_state=initial_state,
229
+ output_final_state=output_final_state,
230
+ offsets=offsets
231
+ )
232
+ ctx.save_for_backward(g, o, initial_state)
233
+ ctx.offsets = offsets
234
+ return o, ht
235
+
236
+ @staticmethod
237
+ @input_guard
238
+ def backward(ctx, do, dht=None):
239
+ g, o, initial_state = ctx.saved_tensors
240
+ offsets = ctx.offsets
241
+
242
+ dx, dg, dh0 = fused_recurrent_hgrn_bwd(
243
+ g=g,
244
+ o=o,
245
+ do=do,
246
+ dht=dht,
247
+ initial_state=initial_state,
248
+ offsets=offsets
249
+ )
250
+ return dx, dg, dh0, None, None
251
+
252
+
253
+ @torch.compiler.disable
254
+ def fused_recurrent_hgrn(
255
+ x: torch.Tensor,
256
+ g: torch.Tensor,
257
+ initial_state: torch.Tensor = None,
258
+ output_final_state: bool = False,
259
+ cu_seqlens: Optional[torch.LongTensor] = None,
260
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
261
+ r"""
262
+ Args:
263
+ x (torch.Tensor):
264
+ inputs of shape `[B, T, D].
265
+ g (torch.Tensor):
266
+ Forget gates of shape `[B, T, D]`.
267
+ initial_state (Optional[torch.Tensor]):
268
+ Initial state of shape `[N, D]` for `N` input sequences.
269
+ For equal-length input sequences, `N` equals the batch size `B`.
270
+ Default: `None`.
271
+ output_final_state (Optional[bool]):
272
+ Whether to output the final state of shape `[N, D]`. Default: `False`.
273
+ cu_seqlens (torch.LongTensor):
274
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
275
+ consistent with the FlashAttention API.
276
+
277
+ Returns:
278
+ o (torch.Tensor):
279
+ Outputs of shape `[B, T, D]`.
280
+ final_state (torch.Tensor):
281
+ Final state of shape `[N, D]` if `output_final_state=True` else `None`.
282
+
283
+ Examples::
284
+ >>> import torch
285
+ >>> import torch.nn.functional as F
286
+ >>> from einops import rearrange
287
+ >>> from fla.ops.hgrn import fused_recurrent_hgrn
288
+ # inputs with equal lengths
289
+ >>> B, T, D = 4, 2048, 512
290
+ >>> x = torch.randn(B, T, D, device='cuda')
291
+ >>> g = F.logsigmoid(torch.randn(B, T, D, device='cuda'))
292
+ >>> h0 = torch.randn(B, D, device='cuda')
293
+ >>> o, ht = fused_recurrent_hgrn(x, g, initial_state=h0, output_final_state=True)
294
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
295
+ >>> x, g = map(lambda x: rearrange(x, 'b t d -> 1 (b t) d'), (x, g))
296
+ # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
297
+ >>> cu_seqlens = x.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
298
+ >>> o_var, ht_var = fused_recurrent_hgrn(x, g, initial_state=h0, output_final_state=True, cu_seqlens=cu_seqlens)
299
+ >>> assert o.allclose(o_var.view(o.shape))
300
+ >>> assert ht.allclose(ht_var)
301
+ """
302
+ return FusedRecurrentHGRNFunction.apply(
303
+ x,
304
+ g,
305
+ initial_state,
306
+ output_final_state,
307
+ cu_seqlens
308
+ )
fla/ops/lightning_attn/__pycache__/fused_recurrent.cpython-312.pyc ADDED
Binary file (3.75 kB). View file
 
fla/ops/linear_attn/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (365 Bytes). View file
 
fla/ops/linear_attn/__pycache__/utils.cpython-312.pyc ADDED
Binary file (554 Bytes). View file
 
fla/ops/linear_attn/fused_recurrent.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+ import triton
8
+ import triton.language as tl
9
+
10
+ from fla.ops.linear_attn.utils import normalize_output
11
+ from fla.utils import input_guard
12
+
13
+
14
+ @triton.jit
15
+ def fused_recurrent_linear_attn_fwd_kernel(
16
+ q, # query [B, H, L, K]
17
+ k, # key [B, H, L, V]
18
+ v, # value [B, H, L, V]
19
+ o, # output [B, H, L, V]
20
+ h0,
21
+ ht, # final hidden state [B, H, K, V]
22
+
23
+ s_k_h, # stride size: L * K
24
+ s_v_h, # stride size: L * V
25
+
26
+ scale,
27
+ B, # batch size
28
+ H, # H
29
+ T, # T
30
+ K: tl.constexpr, # K
31
+ V: tl.constexpr, # V
32
+ BK: tl.constexpr, # BLOCK SIZE along the K dimension
33
+ BV: tl.constexpr, # BLOCK SIZE along the V dimension
34
+ USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
35
+ STORE_FINAL_STATE: tl.constexpr, # whether to store final state
36
+ ):
37
+ # indices
38
+ i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
39
+
40
+ p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK)
41
+ p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK)
42
+ p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV)
43
+ p_o = o + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV)
44
+
45
+ mask_bk = (i_k * BK + tl.arange(0, BK)) < K
46
+ mask_bv = (i_v * BV + tl.arange(0, BV)) < V
47
+ mask_kv = mask_bk[None, :] & mask_bv[:, None]
48
+
49
+ b_h = tl.zeros([BV, BK], dtype=tl.float32)
50
+
51
+ if USE_INITIAL_STATE:
52
+ p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
53
+ b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)
54
+
55
+ for _ in range(0, T):
56
+ b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
57
+ b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
58
+ b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
59
+
60
+ b_h += b_k[None, :] * b_v[:, None]
61
+ b_o = b_h * b_q[None, :]
62
+ b_o = tl.sum(b_o, axis=1)
63
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_bv)
64
+
65
+ p_q += K
66
+ p_k += K
67
+ p_o += V
68
+ p_v += V
69
+
70
+ if STORE_FINAL_STATE:
71
+ p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])
72
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_kv)
73
+
74
+
75
+ # Similar to Algorithm1 of https://arxiv.org/abs/2006.16236
76
+ @triton.jit
77
+ def fused_recurrent_linear_attn_bwd_kernel(
78
+ q, # query [B, H, L, K]
79
+ k, # key [B, H, L, V]
80
+ v, # value [B, H, L, V]
81
+
82
+ do, # gradient of output [B, H, L, V]
83
+ dq, # gradient of query [NV, B, H, L, K]
84
+ dk, # gradient of key [NV, B, H, L, K]
85
+ dv, # gradient of value [NK, B, H, L, V]
86
+ h0, # initial hidden state initialization [B, H, K, V]
87
+
88
+ s_k_h, # stride size: L * K
89
+ s_v_h, # stride size: L * V
90
+ scale, # K ** -0.5
91
+
92
+ B, # B
93
+ H, # H
94
+ T, # T
95
+ K: tl.constexpr, # K
96
+ V: tl.constexpr, # V
97
+ BK: tl.constexpr, # BLOCK SIZE along the K dimension
98
+ BV: tl.constexpr, # BLOCK SIZE along the V dimension
99
+ USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
100
+ ):
101
+ i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
102
+
103
+ p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK)
104
+ p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK)
105
+ p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV)
106
+ p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV)
107
+
108
+ p_dq = dq + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK)
109
+ mask_bk = i_k * BK + tl.arange(0, BK) < K
110
+ mask_bv = i_v * BV + tl.arange(0, BV) < V
111
+
112
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
113
+
114
+ if USE_INITIAL_STATE:
115
+ mask_kv = mask_bk[:, None] & mask_bv[None, :]
116
+ p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])
117
+ b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)
118
+
119
+ for _ in range(0, T):
120
+ b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
121
+ b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
122
+ b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
123
+
124
+ b_h += b_k[:, None] * b_v[None, :]
125
+ _d_q = b_h * b_do[None, :]
126
+ d_q = tl.sum(_d_q, axis=1) * scale
127
+ tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)
128
+
129
+ p_k += K
130
+ p_do += V
131
+ p_v += V
132
+ p_dq += K
133
+
134
+ # sync threads
135
+ tl.debug_barrier()
136
+
137
+ p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K
138
+ p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K
139
+ p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V
140
+ p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V
141
+ p_dk = dk + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K
142
+ p_dv = dv + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V
143
+ d_h = tl.zeros([BK, BV], dtype=tl.float32)
144
+
145
+ for _ in range(T):
146
+ b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)
147
+ b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale
148
+ b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)
149
+ b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)
150
+ d_h += b_q[:, None] * b_do[None, :]
151
+ d_k = tl.sum(d_h * b_v[None, :], axis=1)
152
+ d_v = tl.sum(d_h * b_k[:, None], axis=0)
153
+
154
+ tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)
155
+ tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)
156
+
157
+ p_do -= V
158
+ p_q -= K
159
+ p_k -= K
160
+ p_v -= V
161
+ p_dk -= K
162
+ p_dv -= V
163
+
164
+
165
+ class FusedRecurrentLinearAttentionFunction(torch.autograd.Function):
166
+
167
+ @staticmethod
168
+ @input_guard
169
+ def forward(ctx, q, k, v, scale, initial_state=None, output_final_state=False):
170
+ B, H, T, K = q.shape
171
+ V = v.shape[-1]
172
+
173
+ BK, BV = min(K, 32), min(V, 32)
174
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
175
+ num_warps = 1
176
+ num_stages = 1
177
+
178
+ o = q.new_empty(NK, B, H, T, V)
179
+ final_state = q.new_empty(B, H, K, V) if output_final_state else None
180
+
181
+ grid = (NV, NK, B * H)
182
+ fused_recurrent_linear_attn_fwd_kernel[grid](
183
+ q, k, v, o, initial_state, final_state,
184
+ q.stride(1),
185
+ v.stride(1), scale,
186
+ B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV,
187
+ USE_INITIAL_STATE=initial_state is not None,
188
+ STORE_FINAL_STATE=final_state is not None,
189
+ num_warps=num_warps,
190
+ num_stages=num_stages
191
+ )
192
+
193
+ o = o.sum(0)
194
+ ctx.save_for_backward(q, k, v, initial_state)
195
+ ctx.scale = scale
196
+ return o, final_state
197
+
198
+ @staticmethod
199
+ @input_guard
200
+ def backward(ctx, do, dht=None):
201
+ q, k, v, initial_state = ctx.saved_tensors
202
+ B, H, T, K = q.shape
203
+ V = v.shape[-1]
204
+ scale = ctx.scale
205
+
206
+ BK, BV = min(K, 32), min(V, 32)
207
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
208
+ num_warps = 1
209
+ num_stages = 1
210
+
211
+ dq = q.new_empty(NV, B, H, T, K)
212
+ dk = q.new_empty(NV, B, H, T, K)
213
+ dv = q.new_empty(NK, B, H, T, V)
214
+ grid = (NV, NK, B * H)
215
+
216
+ fused_recurrent_linear_attn_bwd_kernel[grid](
217
+ q, k, v, do, dq, dk, dv, initial_state,
218
+ q.stride(1),
219
+ v.stride(1),
220
+ scale,
221
+ B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV,
222
+ USE_INITIAL_STATE=initial_state is not None,
223
+ num_warps=num_warps,
224
+ num_stages=num_stages
225
+ )
226
+ dq = dq.sum(0)
227
+ dk = dk.sum(0)
228
+ dv = dv.sum(0)
229
+ return dq, dk, dv, None, None, None
230
+
231
+
232
+ def fused_recurrent_linear_attn(
233
+ q: torch.Tensor,
234
+ k: torch.Tensor,
235
+ v: torch.Tensor,
236
+ scale: Optional[float] = None,
237
+ initial_state: torch.Tensor = None,
238
+ output_final_state: bool = False,
239
+ normalize: bool = False,
240
+ head_first: bool = True
241
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
242
+ if scale is None:
243
+ scale = q.shape[-1] ** -0.5
244
+ if not head_first:
245
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
246
+ o, final_state = FusedRecurrentLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state)
247
+ if normalize:
248
+ o = normalize_output(q * scale, k, o)
249
+ if not head_first:
250
+ o = o.transpose(1, 2)
251
+ return o, final_state
fla/ops/nsa/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (268 Bytes). View file
 
fla/ops/nsa/__pycache__/naive.cpython-312.pyc ADDED
Binary file (5.82 kB). View file
 
fla/ops/nsa/__pycache__/parallel.cpython-312.pyc ADDED
Binary file (68.9 kB). View file
 
fla/ops/rebased/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (226 Bytes). View file
 
fla/ops/rebased/parallel.py ADDED
@@ -0,0 +1,466 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
4
+
5
+ import torch
6
+ import triton
7
+ import triton.language as tl
8
+
9
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
10
+
11
+ # Rebased: Linear Transformers with Learnable Kernel Functions are Better In-Context Models
12
+ # https://github.com/corl-team/rebased/blob/main/flash_linear_attention/fla/ops/triton/rebased_fast/parallel.py
13
+
14
+
15
+ @triton.jit(do_not_specialize=['T'])
16
+ def parallel_rebased_fwd_kernel(
17
+ q,
18
+ k,
19
+ v,
20
+ o,
21
+ z,
22
+ scale,
23
+ T,
24
+ B: tl.constexpr,
25
+ H: tl.constexpr,
26
+ K: tl.constexpr,
27
+ V: tl.constexpr,
28
+ BTL: tl.constexpr,
29
+ BTS: tl.constexpr,
30
+ BK: tl.constexpr,
31
+ BV: tl.constexpr,
32
+ ):
33
+ # i_c: chunk index. used for sequence parallelism
34
+ i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
35
+ NV = tl.cdiv(V, BV)
36
+ i_k = i_kv // (NV)
37
+ i_v = i_kv % (NV)
38
+
39
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
40
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k*BK, 0), (BK, BTS), (0, 1))
41
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v*BV), (BTS, BV), (1, 0))
42
+
43
+ # [BQ, BD] block Q, in the shared memory throughout the whole kernel
44
+ b_q = tl.load(p_q, boundary_check=(0, 1))
45
+ b_q = (b_q * scale).to(b_q.dtype)
46
+ b_o = tl.zeros([BTL, BV], dtype=tl.float32)
47
+ b_z = tl.zeros([BTL], dtype=tl.float32)
48
+
49
+ # Q block and K block have no overlap
50
+ # no need for mask, thereby saving flops
51
+ for _ in range(0, i_c*BTL, BTS):
52
+ # [BK, BTS]
53
+ b_k = tl.load(p_k, boundary_check=(0, 1))
54
+
55
+ # [BTS, BV]
56
+ b_v = tl.load(p_v, boundary_check=(0, 1))
57
+ # [BTL, BTS]
58
+ b_s = tl.dot(b_q, (b_k), allow_tf32=False)
59
+ b_s = b_s * b_s
60
+ b_z += tl.sum(b_s, axis=1)
61
+
62
+ # [BQ, BD]
63
+ b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)
64
+ p_k = tl.advance(p_k, (0, BTS))
65
+ p_v = tl.advance(p_v, (BTS, 0))
66
+
67
+ # # rescale interchunk output
68
+ tl.debug_barrier()
69
+ o_q = tl.arange(0, BTL)
70
+ # # sync threads, easy for compiler to optimize
71
+ # tl.debug_barrier()
72
+
73
+ o_k = tl.arange(0, BTS)
74
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k*BK, i_c*BTL), (BK, BTS), (0, 1))
75
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTS, BV), (1, 0))
76
+ # Q block and K block have overlap. masks required
77
+ for _ in range(i_c*BTL, (i_c + 1) * BTL, BTS):
78
+ # [BK, BTS]
79
+ b_k = tl.load(p_k, boundary_check=(0, 1))
80
+ # [BTS, BV]
81
+ b_v = tl.load(p_v, boundary_check=(0, 1))
82
+ # [BTL, BTS]
83
+ m_s = o_q[:, None] >= o_k[None, :]
84
+ b_s = tl.dot(b_q, b_k, allow_tf32=False)
85
+ b_s = b_s * b_s
86
+ b_s = tl.where(m_s, b_s, 0)
87
+ b_z += tl.sum(b_s, axis=1)
88
+ # [BTL, BV]
89
+ b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
90
+ p_k = tl.advance(p_k, (0, BTS))
91
+ p_v = tl.advance(p_v, (BTS, 0))
92
+ o_k += BTS
93
+
94
+ p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
95
+ p_z = z + (i_bh + B * H * i_k) * T + i_c*BTL + tl.arange(0, BTL)
96
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
97
+ tl.store(p_z, b_z.to(p_z.dtype.element_ty), mask=((i_c*BTL + tl.arange(0, BTL)) < T))
98
+
99
+
100
+ @triton.jit(do_not_specialize=['T'])
101
+ def _parallel_rebased_bwd_dq(
102
+ i_bh,
103
+ i_c,
104
+ i_k,
105
+ i_v,
106
+ i_h,
107
+ q,
108
+ k,
109
+ v,
110
+ do,
111
+ dz,
112
+ dq,
113
+ scale,
114
+ T,
115
+ B: tl.constexpr,
116
+ H: tl.constexpr,
117
+ K: tl.constexpr,
118
+ V: tl.constexpr,
119
+ BTL: tl.constexpr,
120
+ BTS: tl.constexpr,
121
+ BK: tl.constexpr,
122
+ BV: tl.constexpr
123
+ ):
124
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
125
+ p_q = tl.make_block_ptr(q + (i_bh) * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
126
+ b_q = tl.load(p_q, boundary_check=(0, 1))
127
+ b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
128
+ b_q = (b_q * scale).to(b_q.dtype)
129
+ b_dq = tl.zeros([BTL, BK], dtype=tl.float32)
130
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (0, i_k*BK), (BTS, BK), (1, 0))
131
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v*BV, 0), (BV, BTS), (0, 1))
132
+ p_dz = dz + i_bh * T + i_c*BTL + tl.arange(0, BTL)
133
+ b_dz = tl.load(p_dz, mask=(i_c*BTL + tl.arange(0, BTL)) < T)
134
+
135
+ for _ in range(0, i_c*BTL, BTS):
136
+ # [BTS, BK]
137
+ b_k = tl.load(p_k, boundary_check=(0, 1))
138
+ # [BV, BTS]
139
+ b_v = tl.load(p_v, boundary_check=(0, 1))
140
+ # [BTL, BTS]
141
+ b_ds = tl.dot(b_do, b_v, allow_tf32=False)
142
+ if i_v == 0:
143
+ b_ds += b_dz[:, None]
144
+ else:
145
+ b_ds = b_ds
146
+ b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
147
+ # [BQ, BD]
148
+ b_dq += tl.dot((2 * b_ds * b_s).to(b_v.dtype), b_k, allow_tf32=False)
149
+ p_k = tl.advance(p_k, (BTS, 0))
150
+ p_v = tl.advance(p_v, (0, BTS))
151
+
152
+ b_dq *= scale
153
+ o_q = tl.arange(0, BTL)
154
+ o_k = tl.arange(0, BTS)
155
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTS, BK), (1, 0))
156
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v*BV, i_c*BTL), (BV, BTS), (0, 1))
157
+ # Q block and K block have overlap. masks required
158
+ for _ in range(i_c*BTL, (i_c + 1) * BTL, BTS):
159
+ # [BTS, BK]
160
+ b_k = tl.load(p_k, boundary_check=(0, 1))
161
+ # [BV, BTS]
162
+ b_v = tl.load(p_v, boundary_check=(0, 1))
163
+ # [BTL, BTS]
164
+ m_s = o_q[:, None] >= o_k[None, :]
165
+ b_ds = tl.dot(b_do, b_v, allow_tf32=False)
166
+ if i_v == 0:
167
+ b_ds += b_dz[:, None]
168
+ else:
169
+ b_ds = b_ds
170
+ b_ds = tl.where(m_s, b_ds, 0) * scale
171
+ b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)
172
+ b_s = tl.where(m_s, b_s, 0)
173
+ # [BTL, BK]
174
+ b_dq += tl.dot((2 * b_ds * b_s).to(b_k.dtype),
175
+ b_k, allow_tf32=False)
176
+ p_k = tl.advance(p_k, (BTS, 0))
177
+ p_v = tl.advance(p_v, (0, BTS))
178
+ o_k += BTS
179
+ p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
180
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
181
+ return
182
+
183
+
184
+ @triton.jit(do_not_specialize=['T'])
185
+ def _parallel_rebased_bwd_dkv(
186
+ i_bh,
187
+ i_c,
188
+ i_k,
189
+ i_v,
190
+ i_h,
191
+ q,
192
+ k,
193
+ v,
194
+ do,
195
+ dz,
196
+ dk,
197
+ dv,
198
+ scale,
199
+ T,
200
+ B: tl.constexpr,
201
+ H: tl.constexpr,
202
+ K: tl.constexpr,
203
+ V: tl.constexpr,
204
+ BTL: tl.constexpr,
205
+ BTS: tl.constexpr,
206
+ BK: tl.constexpr,
207
+ BV: tl.constexpr,
208
+ ):
209
+ # compute dk dv
210
+ p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
211
+ p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
212
+ b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load(p_v, boundary_check=(0, 1))
213
+ b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros(
214
+ [BTL, BV], dtype=tl.float32)
215
+
216
+ for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS):
217
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k*BK, i), (BK, BTS), (0, 1))
218
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (V, T), (1, V), (i_v*BV, i), (BV, BTS), (0, 1))
219
+ p_dz = dz + i_bh * T + i + tl.arange(0, BTS)
220
+ # [BK, BTS]
221
+ b_q = tl.load(p_q, boundary_check=(0, 1))
222
+ # [BV, BTS]
223
+ b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
224
+ b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)
225
+ # [BTL, BTS]
226
+ b_s = tl.dot(b_k.to(b_q.dtype), b_q, allow_tf32=False) * scale
227
+ b_s2 = b_s * b_s
228
+ b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)
229
+ b_ds = tl.dot(b_v, b_do, allow_tf32=False) * scale
230
+ if i_v == 0:
231
+ b_ds += b_dz[None, :] * scale
232
+ else:
233
+ b_ds = b_ds
234
+ b_dk += tl.dot((2 * b_ds * b_s).to(b_q.dtype), tl.trans(b_q), allow_tf32=False)
235
+
236
+ tl.debug_barrier()
237
+ o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL)
238
+ for i in range(i_c*BTL, (i_c+1)*BTL, BTS):
239
+ p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k*BK, i), (BK, BTS), (0, 1))
240
+ p_do = tl.make_block_ptr(do + i_bh * T*V, (V, T), (1, V), (i_v*BV, i), (BV, BTS), (0, 1))
241
+ p_dz = dz + i_bh * T + i + tl.arange(0, BTS)
242
+ b_q = tl.load(p_q, boundary_check=(0, 1)) # [BD, BQ]
243
+ b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)
244
+ b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)
245
+ # [BK, BQ]
246
+ m_s = o_k[:, None] <= o_q[None, :]
247
+ b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale
248
+ b_s2 = b_s * b_s
249
+ b_s = tl.where(m_s, b_s, 0)
250
+ b_s2 = tl.where(m_s, b_s2, 0)
251
+
252
+ b_ds = tl.dot(b_v, b_do, allow_tf32=False)
253
+ if i_v == 0:
254
+ b_ds += b_dz[None, :]
255
+ else:
256
+ b_ds = b_ds
257
+ b_ds = tl.where(m_s, b_ds, 0) * scale
258
+ # [BK, BD]
259
+ b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)
260
+ b_dk += tl.dot((2 * b_ds * b_s).to(b_q.dtype), tl.trans(b_q), allow_tf32=False)
261
+ o_q += BTS
262
+
263
+ p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))
264
+ p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))
265
+ tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
266
+ tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
267
+ return
268
+
269
+
270
+ @triton.jit(do_not_specialize=['T'])
271
+ def parallel_rebased_bwd_kernel(
272
+ q,
273
+ k,
274
+ v,
275
+ do,
276
+ dz,
277
+ dq,
278
+ dk,
279
+ dv,
280
+ scale,
281
+ T,
282
+ B: tl.constexpr,
283
+ H: tl.constexpr,
284
+ K: tl.constexpr,
285
+ V: tl.constexpr,
286
+ BTL: tl.constexpr,
287
+ BTS: tl.constexpr,
288
+ BK: tl.constexpr,
289
+ BV: tl.constexpr
290
+ ):
291
+ i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
292
+ NV = tl.cdiv(V, BV)
293
+ i_k = i_kv // (NV)
294
+ i_v = i_kv % (NV)
295
+ i_h = i_bh % H
296
+ _parallel_rebased_bwd_dq(
297
+ i_bh,
298
+ i_c,
299
+ i_k,
300
+ i_v,
301
+ i_h,
302
+ q,
303
+ k,
304
+ v,
305
+ do,
306
+ dz,
307
+ dq,
308
+ scale,
309
+ B=B,
310
+ H=H,
311
+ T=T,
312
+ K=K,
313
+ V=V,
314
+ BTL=BTL,
315
+ BTS=BTS,
316
+ BK=BK,
317
+ BV=BV
318
+ )
319
+ tl.debug_barrier()
320
+ _parallel_rebased_bwd_dkv(
321
+ i_bh,
322
+ i_c,
323
+ i_k,
324
+ i_v,
325
+ i_h,
326
+ q,
327
+ k,
328
+ v,
329
+ do,
330
+ dz,
331
+ dk,
332
+ dv,
333
+ scale,
334
+ B=B,
335
+ H=H,
336
+ T=T,
337
+ K=K,
338
+ V=V,
339
+ BTL=BTL,
340
+ BTS=BTS,
341
+ BK=BK,
342
+ BV=BV
343
+ )
344
+
345
+
346
+ class ParallelBasedFunction(torch.autograd.Function):
347
+
348
+ @staticmethod
349
+ @input_guard
350
+ @autocast_custom_fwd
351
+ def forward(ctx, q, k, v, scale):
352
+ BTL, BTS = 128, 32
353
+ assert BTL % BTS == 0
354
+ # assert q.shape[-1] % 16 == 0
355
+ BK = min(128, triton.next_power_of_2(k.shape[-1]))
356
+ BV = min(128, triton.next_power_of_2(v.shape[-1]))
357
+ BK, BV = max(BK, 16), max(BV, 16)
358
+ B, H, T, K, V = *k.shape, v.shape[-1]
359
+ num_stages = 2
360
+ num_warps = 4
361
+ NK = triton.cdiv(K, BK)
362
+ NV = triton.cdiv(V, BV)
363
+ grid = (NK * NV, triton.cdiv(T, BTL), B * H)
364
+
365
+ assert NK == 1, "will encounter some synchronization issue if not."
366
+
367
+ o = torch.empty(NK, B, H, T, V, device=q.device)
368
+ z = torch.empty(NK, B, H, T, device=q.device)
369
+ parallel_rebased_fwd_kernel[grid](
370
+ q,
371
+ k,
372
+ v,
373
+ o,
374
+ z,
375
+ scale,
376
+ T=T,
377
+ B=B,
378
+ H=H,
379
+ K=K,
380
+ V=V,
381
+ BTL=BTL,
382
+ BTS=BTS,
383
+ BK=BK,
384
+ BV=BV,
385
+ num_warps=num_warps,
386
+ num_stages=num_stages
387
+ )
388
+ ctx.save_for_backward(q, k, v)
389
+ ctx.scale = scale
390
+ return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype)
391
+
392
+ @staticmethod
393
+ @input_guard
394
+ @autocast_custom_bwd
395
+ def backward(ctx, do, dz):
396
+ q, k, v = ctx.saved_tensors
397
+ scale = ctx.scale
398
+ BTL, BTS = 64, 32
399
+ assert BTL % BTS == 0
400
+ BK = min(128, triton.next_power_of_2(k.shape[-1]))
401
+ BV = min(128, triton.next_power_of_2(v.shape[-1]))
402
+ BK, BV = max(BK, 16), max(BV, 16)
403
+ B, H, T, K, V = *k.shape, v.shape[-1]
404
+ num_stages = 2
405
+ num_warps = 4
406
+ NK = triton.cdiv(K, BK)
407
+ NV = triton.cdiv(V, BV)
408
+ grid = (NK * NV, triton.cdiv(T, BTL), B * H)
409
+
410
+ assert NK == 1, "will encounter some synchronization issue if not"
411
+
412
+ dq = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device)
413
+ dk = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device)
414
+ dv = torch.empty(NK, B, H, T, V, dtype=q.dtype, device=q.device)
415
+
416
+ parallel_rebased_bwd_kernel[grid](
417
+ q,
418
+ k,
419
+ v,
420
+ do,
421
+ dz,
422
+ dq,
423
+ dk,
424
+ dv,
425
+ scale,
426
+ T=T,
427
+ B=B,
428
+ H=H,
429
+ K=K,
430
+ V=V,
431
+ BTL=BTL,
432
+ BTS=BTS,
433
+ BK=BK,
434
+ BV=BV,
435
+ num_warps=num_warps,
436
+ num_stages=num_stages
437
+ )
438
+
439
+ return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None
440
+
441
+
442
+ def parallel_rebased(
443
+ q: torch.Tensor,
444
+ k: torch.Tensor,
445
+ v: torch.Tensor,
446
+ eps: float = 1e-5,
447
+ use_scale: bool = True,
448
+ use_normalize: bool = True,
449
+ return_both: bool = False,
450
+ head_first: bool = True
451
+ ):
452
+ assert q.shape[-1] <= 128, "only support feature dim up to 128"
453
+ if use_scale:
454
+ scale = q.shape[-1] ** -0.5
455
+ else:
456
+ scale = 1
457
+ if not head_first:
458
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
459
+ o, z = ParallelBasedFunction.apply(q, k, v, scale)
460
+ if return_both:
461
+ return o, z
462
+ if use_normalize:
463
+ o = o / (z[..., None] + eps)
464
+ if not head_first:
465
+ o = o.transpose(1, 2)
466
+ return o.to(q.dtype)
fla/ops/retention/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (414 Bytes). View file