danieldk committed
Commit c880638 · 1 parent: 8e88928

Revert "Build uploaded using `kernels`."

This reverts commit 8e88928ad44de1c64da65f0b82088ab3af4cdb49.

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.

Files changed (50)
  1. build/torch210-cxx11-cu126-x86_64-linux/__init__.py +14 -0
  2. build/torch210-cxx11-cu126-x86_64-linux/_mamba_ssm_b2a7fd5.abi3.so +3 -0
  3. build/torch210-cxx11-cu126-x86_64-linux/_ops.py +9 -0
  4. build/torch210-cxx11-cu126-x86_64-linux/distributed/__init__.py +0 -0
  5. build/torch210-cxx11-cu126-x86_64-linux/distributed/distributed_utils.py +144 -0
  6. build/torch210-cxx11-cu126-x86_64-linux/distributed/tensor_parallel.py +296 -0
  7. build/torch210-cxx11-cu126-x86_64-linux/mamba_ssm/__init__.py +26 -0
  8. build/torch210-cxx11-cu126-x86_64-linux/metadata.json +1 -0
  9. build/torch210-cxx11-cu126-x86_64-linux/models/__init__.py +0 -0
  10. build/torch210-cxx11-cu126-x86_64-linux/models/config_mamba.py +18 -0
  11. build/torch210-cxx11-cu126-x86_64-linux/models/mixer_seq_simple.py +309 -0
  12. build/torch210-cxx11-cu126-x86_64-linux/modules/__init__.py +0 -0
  13. build/torch210-cxx11-cu126-x86_64-linux/modules/block.py +107 -0
  14. build/torch210-cxx11-cu126-x86_64-linux/modules/mamba2.py +502 -0
  15. build/torch210-cxx11-cu126-x86_64-linux/modules/mamba2_simple.py +229 -0
  16. build/torch210-cxx11-cu126-x86_64-linux/modules/mamba_simple.py +339 -0
  17. build/torch210-cxx11-cu126-x86_64-linux/modules/mha.py +294 -0
  18. build/torch210-cxx11-cu126-x86_64-linux/modules/mlp.py +34 -0
  19. build/torch210-cxx11-cu126-x86_64-linux/modules/ssd_minimal.py +111 -0
  20. build/torch210-cxx11-cu126-x86_64-linux/ops/__init__.py +0 -0
  21. build/torch210-cxx11-cu126-x86_64-linux/ops/selective_scan_interface.py +446 -0
  22. build/torch210-cxx11-cu126-x86_64-linux/ops/triton/__init__.py +0 -0
  23. build/torch210-cxx11-cu126-x86_64-linux/ops/triton/k_activations.py +169 -0
  24. build/torch210-cxx11-cu126-x86_64-linux/ops/triton/layer_norm.py +1113 -0
  25. build/torch210-cxx11-cu126-x86_64-linux/ops/triton/layernorm_gated.py +437 -0
  26. build/torch210-cxx11-cu126-x86_64-linux/ops/triton/selective_state_update.py +285 -0
  27. build/torch210-cxx11-cu126-x86_64-linux/ops/triton/softplus.py +15 -0
  28. build/torch210-cxx11-cu126-x86_64-linux/ops/triton/ssd_bmm.py +262 -0
  29. build/torch210-cxx11-cu126-x86_64-linux/ops/triton/ssd_chunk_scan.py +0 -0
  30. build/torch210-cxx11-cu126-x86_64-linux/ops/triton/ssd_chunk_state.py +997 -0
  31. build/torch210-cxx11-cu126-x86_64-linux/ops/triton/ssd_combined.py +998 -0
  32. build/torch210-cxx11-cu126-x86_64-linux/ops/triton/ssd_state_passing.py +348 -0
  33. build/torch210-cxx11-cu126-x86_64-linux/utils/__init__.py +0 -0
  34. build/torch210-cxx11-cu126-x86_64-linux/utils/generation.py +390 -0
  35. build/torch210-cxx11-cu126-x86_64-linux/utils/hf.py +23 -0
  36. build/torch210-cxx11-cu126-x86_64-linux/utils/torch.py +21 -0
  37. build/torch210-cxx11-cu128-x86_64-linux/__init__.py +14 -0
  38. build/torch210-cxx11-cu128-x86_64-linux/_mamba_ssm_b2a7fd5.abi3.so +3 -0
  39. build/torch210-cxx11-cu128-x86_64-linux/_ops.py +9 -0
  40. build/torch210-cxx11-cu128-x86_64-linux/distributed/__init__.py +0 -0
  41. build/torch210-cxx11-cu128-x86_64-linux/distributed/distributed_utils.py +144 -0
  42. build/torch210-cxx11-cu128-x86_64-linux/distributed/tensor_parallel.py +296 -0
  43. build/torch210-cxx11-cu128-x86_64-linux/mamba_ssm/__init__.py +26 -0
  44. build/torch210-cxx11-cu128-x86_64-linux/metadata.json +1 -0
  45. build/torch210-cxx11-cu128-x86_64-linux/models/__init__.py +0 -0
  46. build/torch210-cxx11-cu128-x86_64-linux/models/config_mamba.py +18 -0
  47. build/torch210-cxx11-cu128-x86_64-linux/models/mixer_seq_simple.py +309 -0
  48. build/torch210-cxx11-cu128-x86_64-linux/modules/__init__.py +0 -0
  49. build/torch210-cxx11-cu128-x86_64-linux/modules/block.py +107 -0
  50. build/torch210-cxx11-cu128-x86_64-linux/modules/mamba2.py +502 -0
build/torch210-cxx11-cu126-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,14 @@
+ __version__ = "2.2.4"
+
+ from .ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
+ from .modules.mamba_simple import Mamba
+ from .modules.mamba2 import Mamba2
+ from .models.mixer_seq_simple import MambaLMHeadModel
+
+ __all__ = [
+     "selective_scan_fn",
+     "mamba_inner_fn",
+     "Mamba",
+     "Mamba2",
+     "MambaLMHeadModel",
+ ]
build/torch210-cxx11-cu126-x86_64-linux/_mamba_ssm_b2a7fd5.abi3.so ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:19b5ffd35a9fd55231325ac14270580c019395c0acb3e4e251518042b50b1aed
+ size 444257200
build/torch210-cxx11-cu126-x86_64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
+ import torch
+ from . import _mamba_ssm_b2a7fd5
+ ops = torch.ops._mamba_ssm_b2a7fd5
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_mamba_ssm_b2a7fd5::{op_name}"
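A rough usage sketch for this shim (not part of the commit): importing `_ops` loads the shared object, which registers its custom ops under the `_mamba_ssm_b2a7fd5` namespace, and `add_op_namespace_prefix` builds fully qualified op names. The op name below is a hypothetical placeholder, not an op guaranteed to exist in this build.

    # Assuming this runs inside the built package directory
    from ._ops import ops, add_op_namespace_prefix

    qualified = add_op_namespace_prefix("selective_scan_fwd")  # placeholder name
    # -> "_mamba_ssm_b2a7fd5::selective_scan_fwd"
    # Real call sites dispatch through the bound namespace, e.g. ops.<op_name>(...)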
build/torch210-cxx11-cu126-x86_64-linux/distributed/__init__.py ADDED
File without changes
build/torch210-cxx11-cu126-x86_64-linux/distributed/distributed_utils.py ADDED
@@ -0,0 +1,144 @@
+ from typing import Optional
+
+ import torch
+ from torch import Tensor
+ from torch.distributed import ProcessGroup
+
+ # `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
+ # `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
+ # version of PyTorch. The following 4 lines are for backward compatibility with
+ # older PyTorch.
+ if "all_gather_into_tensor" not in dir(torch.distributed):
+     torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
+ if "reduce_scatter_tensor" not in dir(torch.distributed):
+     torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
+
+
+ # Raw operation, does not support autograd, but does support async
+ def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
+     world_size = torch.distributed.get_world_size(process_group)
+     output = torch.empty(
+         world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
+     )
+     handle = torch.distributed.all_gather_into_tensor(
+         output, input_.contiguous(), group=process_group, async_op=async_op
+     )
+     return output, handle
+
+
+ # Raw operation, does not support autograd, but does support async
+ def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
+     world_size = torch.distributed.get_world_size(process_group)
+     assert input_.shape[0] % world_size == 0
+     output = torch.empty(
+         input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
+     )
+     handle = torch.distributed.reduce_scatter_tensor(
+         output, input_.contiguous(), group=process_group, async_op=async_op
+     )
+     return output, handle
+
+
+ # Raw operation, does not support autograd, but does support async
+ def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
+     input_ = input_.contiguous()
+     handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
+     return input_, handle
+
+
+ class AllGatherFunc(torch.autograd.Function):
+     """Gather the input from sequence parallel region and concatenate."""
+
+     @staticmethod
+     def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
+         ctx.process_group = process_group
+         output, _ = all_gather_raw(input_, process_group)
+         return output
+
+     @staticmethod
+     def backward(ctx, grad_output: Tensor):
+         grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
+         return grad_input, None
+
+
+ # Supports autograd, but does not support async
+ all_gather = AllGatherFunc.apply
+
+
+ class ReduceScatterFunc(torch.autograd.Function):
+     """Reduce scatter the input from the sequence parallel region and concatenate."""
+
+     @staticmethod
+     def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
+         ctx.process_group = process_group
+         output, _ = reduce_scatter_raw(input_, process_group)
+         return output
+
+     @staticmethod
+     def backward(ctx, grad_output: Tensor):
+         grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
+         return grad_input, None
+
+
+ # Supports autograd, but does not support async
+ reduce_scatter = ReduceScatterFunc.apply
+
+
+ class AllReduceFunc(torch.autograd.Function):
+     """Gather the input from sequence parallel region and concatenate."""
+
+     @staticmethod
+     def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
+         ctx.process_group = process_group
+         output, _ = all_reduce_raw(input_, process_group)
+         return output
+
+     @staticmethod
+     def backward(ctx, grad_output: Tensor):
+         return grad_output, None
+
+
+ # Supports autograd, but does not support async
+ all_reduce = AllReduceFunc.apply
+
+
+ def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
+     # We want to iterate over parameters with _shared_params=True in the same order,
+     # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
+     pamams_shared = {
+         name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
+     }
+     for _, p in sorted(pamams_shared.items()):
+         with torch.no_grad():
+             # Broadcast needs src to be global rank, not group rank
+             torch.distributed.broadcast(
+                 p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
+             )
+
+
+ # Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
+ def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
+     # We want to iterate over parameters with _sequence_parallel=True in the same order,
+     # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
+     params_seqparallel = {
+         name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
+     }
+     grads = [p.grad for _, p in sorted(params_seqparallel.items())]
+     if grads:
+         with torch.no_grad():
+             coalesced = torch._utils._flatten_dense_tensors(grads)
+             torch.distributed.all_reduce(coalesced, group=process_group)
+             for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
+                 buf.copy_(synced)
+
+
+ def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
+     """Get the dim for the local rank derived from splitting dim on world_size processes.
+
+     The split may not be even across the world_size processes.
+     """
+     multiple = dim // multiple_of
+     div = multiple // world_size
+     mod = multiple % world_size
+     local_multiple = div + int(local_rank < mod)
+     return local_multiple * multiple_of
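A quick worked example of the uneven split computed by `get_dim_for_local_rank` (numbers chosen only for illustration): with dim=10, world_size=4, multiple_of=1, we get multiple=10, div=2, mod=2, so local ranks 0 and 1 receive 3 while ranks 2 and 3 receive 2, which sums back to 10.

    # Hypothetical self-check of the arithmetic above
    sizes = [get_dim_for_local_rank(10, 4, r) for r in range(4)]
    assert sizes == [3, 3, 2, 2] and sum(sizes) == 10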
build/torch210-cxx11-cu126-x86_64-linux/distributed/tensor_parallel.py ADDED
@@ -0,0 +1,296 @@
1
+ # Copyright (c) 2024, Tri Dao.
2
+ # The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
3
+ from typing import Optional
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch import Tensor
9
+ from torch.distributed import ProcessGroup
10
+ from ..utils.torch import custom_bwd, custom_fwd
11
+
12
+ from einops import rearrange
13
+
14
+ from ..distributed.distributed_utils import (
15
+ all_gather_raw,
16
+ all_reduce,
17
+ all_reduce_raw,
18
+ reduce_scatter,
19
+ reduce_scatter_raw,
20
+ )
21
+
22
+
23
+ class ParallelLinearFunc(torch.autograd.Function):
24
+ @staticmethod
25
+ @custom_fwd
26
+ def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
27
+ """
28
+ If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
29
+ with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
30
+ """
31
+ ctx.compute_weight_gradient = weight.requires_grad
32
+ ctx.process_group = process_group
33
+ ctx.sequence_parallel = sequence_parallel
34
+
35
+ if torch.is_autocast_enabled():
36
+ x = x.to(dtype=torch.get_autocast_gpu_dtype())
37
+ x = x.contiguous()
38
+ if process_group is not None and sequence_parallel:
39
+ # We want to kick off the all_gather early, before weight dtype conversion
40
+ total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
41
+ else:
42
+ total_x = x
43
+
44
+ if torch.is_autocast_enabled():
45
+ weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
46
+ bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
47
+ weight = weight.contiguous()
48
+ if process_group is not None and sequence_parallel:
49
+ handle_x.wait()
50
+ batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
51
+ batch_dim = batch_shape.numel()
52
+ # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
53
+ output = F.linear(total_x, weight, bias)
54
+ if ctx.compute_weight_gradient:
55
+ ctx.save_for_backward(x, weight)
56
+ else:
57
+ ctx.save_for_backward(weight)
58
+ return output
59
+
60
+ @staticmethod
61
+ @custom_bwd
62
+ def backward(ctx, grad_output):
63
+ grad_output = grad_output.contiguous()
64
+ process_group = ctx.process_group
65
+ sequence_parallel = ctx.sequence_parallel
66
+ if ctx.compute_weight_gradient:
67
+ x, weight = ctx.saved_tensors
68
+ if process_group is not None and sequence_parallel:
69
+ total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
70
+ else:
71
+ total_x = x
72
+ else:
73
+ (weight,) = ctx.saved_tensors
74
+ total_x = None
75
+ batch_shape = grad_output.shape[:-1]
76
+ batch_dim = batch_shape.numel()
77
+ grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
78
+ if ctx.needs_input_grad[0]:
79
+ grad_input = F.linear(grad_output, weight.t())
80
+ grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
81
+ if process_group is not None:
82
+ reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
83
+ grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
84
+ else:
85
+ grad_input = None
86
+ if ctx.needs_input_grad[1]:
87
+ assert ctx.compute_weight_gradient
88
+ if process_group is not None and sequence_parallel:
89
+ handle_x.wait()
90
+ grad_weight = torch.einsum(
91
+ "bo,bi->oi", grad_output, total_x.reshape(batch_dim, total_x.shape[-1])
92
+ )
93
+ else:
94
+ grad_weight = None
95
+ grad_bias = grad_output.sum(dim=0) if ctx.needs_input_grad[2] else None
96
+ if process_group is not None and ctx.needs_input_grad[0]:
97
+ handle_grad_input.wait()
98
+ return grad_input, grad_weight, grad_bias, None, None
99
+
100
+
101
+ def parallel_linear_func(
102
+ x: Tensor,
103
+ weight: Tensor,
104
+ bias: Optional[Tensor] = None,
105
+ process_group: Optional[ProcessGroup] = None,
106
+ sequence_parallel: bool = True,
107
+ ):
108
+ return ParallelLinearFunc.apply(x, weight, bias, process_group, sequence_parallel)
109
+
110
+
111
+ class ColumnParallelLinear(nn.Linear):
112
+ def __init__(
113
+ self,
114
+ in_features: int,
115
+ out_features: int,
116
+ process_group: ProcessGroup,
117
+ bias: bool = True,
118
+ sequence_parallel=True,
119
+ multiple_of=1,
120
+ device=None,
121
+ dtype=None,
122
+ ) -> None:
123
+ world_size = torch.distributed.get_world_size(process_group)
124
+ if out_features % multiple_of:
125
+ raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}")
126
+ multiple = out_features // multiple_of
127
+ # We want to split @multiple across world_size, but it could be an uneven split
128
+ div = multiple // world_size
129
+ mod = multiple % world_size
130
+ # The first @mod ranks get @div + 1 copies, the rest get @div copies
131
+ local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
132
+ super().__init__(
133
+ in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype
134
+ )
135
+ self.process_group = process_group
136
+ self.sequence_parallel = sequence_parallel
137
+
138
+ def forward(self, x):
139
+ # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
140
+ # we do an all_gather of x before doing the matmul.
141
+ # If not, then the input is already gathered.
142
+ return parallel_linear_func(
143
+ x,
144
+ self.weight,
145
+ self.bias,
146
+ process_group=self.process_group,
147
+ sequence_parallel=self.sequence_parallel,
148
+ )
149
+
150
+
151
+ class RowParallelLinear(nn.Linear):
152
+ def __init__(
153
+ self,
154
+ in_features: int,
155
+ out_features: int,
156
+ process_group: ProcessGroup,
157
+ bias: bool = True,
158
+ sequence_parallel=True,
159
+ multiple_of=1,
160
+ device=None,
161
+ dtype=None,
162
+ ) -> None:
163
+ world_size = torch.distributed.get_world_size(process_group)
164
+ rank = torch.distributed.get_rank(process_group)
165
+ if in_features % multiple_of:
166
+ raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}")
167
+ multiple = in_features // multiple_of
168
+ # We want to split @multiple across world_size, but it could be an uneven split
169
+ div = multiple // world_size
170
+ mod = multiple % world_size
171
+ # The first @mod ranks get @div + 1 copies, the rest get @div copies
172
+ local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
173
+ # Only rank 0 will have bias
174
+ super().__init__(
175
+ local_multiple * multiple_of,
176
+ out_features,
177
+ bias=bias and rank == 0,
178
+ device=device,
179
+ dtype=dtype,
180
+ )
181
+ self.process_group = process_group
182
+ self.sequence_parallel = sequence_parallel
183
+
184
+ def forward(self, x):
185
+ """
186
+ We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
187
+ a reduce_scatter of the result.
188
+ """
189
+ out = parallel_linear_func(x, self.weight, self.bias)
190
+ reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
191
+ return reduce_fn(out, self.process_group)
192
+
193
+
194
+ class VocabParallelEmbedding(nn.Embedding):
195
+ def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
196
+ self.process_group = process_group
197
+ if process_group is not None:
198
+ world_size = torch.distributed.get_world_size(process_group)
199
+ if num_embeddings % world_size != 0:
200
+ raise ValueError(
201
+ f"num_embeddings ({num_embeddings}) must be divisible by "
202
+ f"world_size ({world_size})"
203
+ )
204
+ if world_size > 1 and padding_idx is not None:
205
+ raise RuntimeError("ParallelEmbedding does not support padding_idx")
206
+ else:
207
+ world_size = 1
208
+ super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)
209
+
210
+ def forward(self, input: Tensor) -> Tensor:
211
+ if self.process_group is None:
212
+ return super().forward(input)
213
+ else:
214
+ rank = torch.distributed.get_rank(self.process_group)
215
+ vocab_size = self.num_embeddings
216
+ vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
217
+ # Create a mask of valid vocab ids (1 means it needs to be masked).
218
+ input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
219
+ input = input - vocab_start_index
220
+ input[input_ids_mask] = 0
221
+ embeddings = super().forward(input)
222
+ embeddings[input_ids_mask] = 0.0
223
+ return embeddings
224
+
225
+
226
+ class ColumnParallelEmbedding(nn.Embedding):
227
+ def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):
228
+ self.process_group = process_group
229
+ if process_group is not None:
230
+ world_size = torch.distributed.get_world_size(process_group)
231
+ if embedding_dim % world_size != 0:
232
+ raise ValueError(
233
+ f"embedding_dim ({embedding_dim}) must be divisible by "
234
+ f"world_size ({world_size})"
235
+ )
236
+ else:
237
+ world_size = 1
238
+ super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)
239
+
240
+
241
+ class ParallelEmbeddings(nn.Module):
242
+ def __init__(
243
+ self,
244
+ embed_dim,
245
+ vocab_size,
246
+ max_position_embeddings,
247
+ process_group,
248
+ padding_idx=None,
249
+ sequence_parallel=True,
250
+ device=None,
251
+ dtype=None,
252
+ ):
253
+ """
254
+ If max_position_embeddings <= 0, there's no position embeddings
255
+ """
256
+ factory_kwargs = {"device": device, "dtype": dtype}
257
+ super().__init__()
258
+ self.process_group = process_group
259
+ self.sequence_parallel = sequence_parallel
260
+ self.word_embeddings = VocabParallelEmbedding(
261
+ vocab_size,
262
+ embed_dim,
263
+ padding_idx=padding_idx,
264
+ process_group=process_group,
265
+ **factory_kwargs,
266
+ )
267
+ self.max_position_embeddings = max_position_embeddings
268
+ if self.max_position_embeddings > 0:
269
+ self.position_embeddings = ColumnParallelEmbedding(
270
+ max_position_embeddings, embed_dim, process_group=process_group, **factory_kwargs
271
+ )
272
+
273
+ def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
274
+ """
275
+ input_ids: (batch, seqlen)
276
+ position_ids: (batch, seqlen)
277
+ """
278
+ batch_size, seqlen = input_ids.shape
279
+ world_size = torch.distributed.get_world_size(self.process_group)
280
+ embeddings = self.word_embeddings(input_ids)
281
+ if self.max_position_embeddings > 0:
282
+ if position_ids is None:
283
+ position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
284
+ position_embeddings = self.position_embeddings(position_ids)
285
+ if world_size <= 1:
286
+ embeddings = embeddings + position_embeddings
287
+ else:
288
+ partition_dim = self.position_embeddings.embedding_dim
289
+ rank = torch.distributed.get_rank(self.process_group)
290
+ embeddings[
291
+ ..., rank * partition_dim : (rank + 1) * partition_dim
292
+ ] += position_embeddings
293
+ if combine_batch_seqlen_dim:
294
+ embeddings = rearrange(embeddings, "b s d -> (b s) d")
295
+ reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
296
+ return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
build/torch210-cxx11-cu126-x86_64-linux/mamba_ssm/__init__.py ADDED
@@ -0,0 +1,26 @@
+ import ctypes
+ import sys
+
+ import importlib
+ from pathlib import Path
+ from types import ModuleType
+
+ def _import_from_path(file_path: Path) -> ModuleType:
+     # We cannot use the module name as-is, after adding it to `sys.modules`,
+     # it would also be used for other imports. So, we make a module name that
+     # depends on the path for it to be unique using the hex-encoded hash of
+     # the path.
+     path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
+     module_name = path_hash
+     spec = importlib.util.spec_from_file_location(module_name, file_path)
+     if spec is None:
+         raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
+     module = importlib.util.module_from_spec(spec)
+     if module is None:
+         raise ImportError(f"Cannot load module {module_name} from spec")
+     sys.modules[module_name] = module
+     spec.loader.exec_module(module)  # type: ignore
+     return module
+
+
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch210-cxx11-cu126-x86_64-linux/metadata.json ADDED
@@ -0,0 +1 @@
+ {"python-depends":[]}
build/torch210-cxx11-cu126-x86_64-linux/models/__init__.py ADDED
File without changes
build/torch210-cxx11-cu126-x86_64-linux/models/config_mamba.py ADDED
@@ -0,0 +1,18 @@
+ from dataclasses import dataclass, field
+
+
+ @dataclass
+ class MambaConfig:
+
+     d_model: int = 2560
+     d_intermediate: int = 0
+     n_layer: int = 64
+     vocab_size: int = 50277
+     ssm_cfg: dict = field(default_factory=dict)
+     attn_layer_idx: list = field(default_factory=list)
+     attn_cfg: dict = field(default_factory=dict)
+     rms_norm: bool = True
+     residual_in_fp32: bool = True
+     fused_add_norm: bool = True
+     pad_vocab_size_multiple: int = 8
+     tie_embeddings: bool = True
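For orientation (not part of the commit): the defaults above roughly match the largest released Mamba configuration, and smaller models are described by overriding fields. The import path and values below are illustrative only.

    # Minimal sketch, assuming the built directory is importable as a package
    from models.config_mamba import MambaConfig

    cfg = MambaConfig(d_model=768, n_layer=24, ssm_cfg={"layer": "Mamba2"})
    # create_block() in mixer_seq_simple.py pops ssm_cfg["layer"] to choose Mamba vs. Mamba2,
    # and MambaLMHeadModel pads vocab_size up to a multiple of cfg.pad_vocab_size_multiple (8).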
build/torch210-cxx11-cu126-x86_64-linux/models/mixer_seq_simple.py ADDED
@@ -0,0 +1,309 @@
1
+ # Copyright (c) 2023, Albert Gu, Tri Dao.
2
+
3
+ import math
4
+ from functools import partial
5
+ import json
6
+ import os
7
+ import copy
8
+
9
+ from collections import namedtuple
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+ from .config_mamba import MambaConfig
15
+ from ..modules.mamba_simple import Mamba
16
+ from ..modules.mamba2 import Mamba2
17
+ from ..modules.mha import MHA
18
+ from ..modules.mlp import GatedMLP
19
+ from ..modules.block import Block
20
+ from ..utils.generation import GenerationMixin
21
+ from ..utils.hf import load_config_hf, load_state_dict_hf
22
+
23
+ try:
24
+ from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
25
+ except ImportError:
26
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
27
+
28
+
29
+ def create_block(
30
+ d_model,
31
+ d_intermediate,
32
+ ssm_cfg=None,
33
+ attn_layer_idx=None,
34
+ attn_cfg=None,
35
+ norm_epsilon=1e-5,
36
+ rms_norm=False,
37
+ residual_in_fp32=False,
38
+ fused_add_norm=False,
39
+ layer_idx=None,
40
+ device=None,
41
+ dtype=None,
42
+ ):
43
+ if ssm_cfg is None:
44
+ ssm_cfg = {}
45
+ if attn_layer_idx is None:
46
+ attn_layer_idx = []
47
+ if attn_cfg is None:
48
+ attn_cfg = {}
49
+ factory_kwargs = {"device": device, "dtype": dtype}
50
+ if layer_idx not in attn_layer_idx:
51
+ # Create a copy of the config to modify
52
+ ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
53
+ ssm_layer = ssm_cfg.pop("layer", "Mamba1")
54
+ if ssm_layer not in ["Mamba1", "Mamba2"]:
55
+ raise ValueError(f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2")
56
+ mixer_cls = partial(
57
+ Mamba2 if ssm_layer == "Mamba2" else Mamba,
58
+ layer_idx=layer_idx,
59
+ **ssm_cfg,
60
+ **factory_kwargs
61
+ )
62
+ else:
63
+ mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
64
+ norm_cls = partial(
65
+ nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
66
+ )
67
+ if d_intermediate == 0:
68
+ mlp_cls = nn.Identity
69
+ else:
70
+ mlp_cls = partial(
71
+ GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs
72
+ )
73
+ block = Block(
74
+ d_model,
75
+ mixer_cls,
76
+ mlp_cls,
77
+ norm_cls=norm_cls,
78
+ fused_add_norm=fused_add_norm,
79
+ residual_in_fp32=residual_in_fp32,
80
+ )
81
+ block.layer_idx = layer_idx
82
+ return block
83
+
84
+
85
+ # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
86
+ def _init_weights(
87
+ module,
88
+ n_layer,
89
+ initializer_range=0.02, # Now only used for embedding layer.
90
+ rescale_prenorm_residual=True,
91
+ n_residuals_per_layer=1, # Change to 2 if we have MLP
92
+ ):
93
+ if isinstance(module, nn.Linear):
94
+ if module.bias is not None:
95
+ if not getattr(module.bias, "_no_reinit", False):
96
+ nn.init.zeros_(module.bias)
97
+ elif isinstance(module, nn.Embedding):
98
+ nn.init.normal_(module.weight, std=initializer_range)
99
+
100
+ if rescale_prenorm_residual:
101
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
102
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
103
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
104
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
105
+ #
106
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
107
+ for name, p in module.named_parameters():
108
+ if name in ["out_proj.weight", "fc2.weight"]:
109
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
110
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
111
+ # We need to reinit p since this code could be called multiple times
112
+ # Having just p *= scale would repeatedly scale it down
113
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
114
+ with torch.no_grad():
115
+ p /= math.sqrt(n_residuals_per_layer * n_layer)
116
+
117
+
118
+ class MixerModel(nn.Module):
119
+ def __init__(
120
+ self,
121
+ d_model: int,
122
+ n_layer: int,
123
+ d_intermediate: int,
124
+ vocab_size: int,
125
+ ssm_cfg=None,
126
+ attn_layer_idx=None,
127
+ attn_cfg=None,
128
+ norm_epsilon: float = 1e-5,
129
+ rms_norm: bool = False,
130
+ initializer_cfg=None,
131
+ fused_add_norm=False,
132
+ residual_in_fp32=False,
133
+ device=None,
134
+ dtype=None,
135
+ ) -> None:
136
+ factory_kwargs = {"device": device, "dtype": dtype}
137
+ super().__init__()
138
+ self.residual_in_fp32 = residual_in_fp32
139
+
140
+ self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
141
+
142
+ # We change the order of residual and layer norm:
143
+ # Instead of LN -> Attn / MLP -> Add, we do:
144
+ # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
145
+ # the main branch (output of MLP / Mixer). The model definition is unchanged.
146
+ # This is for performance reason: we can fuse add + layer_norm.
147
+ self.fused_add_norm = fused_add_norm
148
+ if self.fused_add_norm:
149
+ if layer_norm_fn is None or rms_norm_fn is None:
150
+ raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
151
+
152
+ self.layers = nn.ModuleList(
153
+ [
154
+ create_block(
155
+ d_model,
156
+ d_intermediate=d_intermediate,
157
+ ssm_cfg=ssm_cfg,
158
+ attn_layer_idx=attn_layer_idx,
159
+ attn_cfg=attn_cfg,
160
+ norm_epsilon=norm_epsilon,
161
+ rms_norm=rms_norm,
162
+ residual_in_fp32=residual_in_fp32,
163
+ fused_add_norm=fused_add_norm,
164
+ layer_idx=i,
165
+ **factory_kwargs,
166
+ )
167
+ for i in range(n_layer)
168
+ ]
169
+ )
170
+
171
+ self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
172
+ d_model, eps=norm_epsilon, **factory_kwargs
173
+ )
174
+
175
+ self.apply(
176
+ partial(
177
+ _init_weights,
178
+ n_layer=n_layer,
179
+ **(initializer_cfg if initializer_cfg is not None else {}),
180
+ n_residuals_per_layer=1 if d_intermediate == 0 else 2, # 2 if we have MLP
181
+ )
182
+ )
183
+
184
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
185
+ return {
186
+ i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
187
+ for i, layer in enumerate(self.layers)
188
+ }
189
+
190
+ def forward(self, input_ids, inference_params=None, **mixer_kwargs):
191
+ hidden_states = self.embedding(input_ids)
192
+ residual = None
193
+ for layer in self.layers:
194
+ hidden_states, residual = layer(
195
+ hidden_states, residual, inference_params=inference_params, **mixer_kwargs
196
+ )
197
+ if not self.fused_add_norm:
198
+ residual = (hidden_states + residual) if residual is not None else hidden_states
199
+ hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
200
+ else:
201
+ # Set prenorm=False here since we don't need the residual
202
+ hidden_states = layer_norm_fn(
203
+ hidden_states,
204
+ self.norm_f.weight,
205
+ self.norm_f.bias,
206
+ eps=self.norm_f.eps,
207
+ residual=residual,
208
+ prenorm=False,
209
+ residual_in_fp32=self.residual_in_fp32,
210
+ is_rms_norm=isinstance(self.norm_f, RMSNorm)
211
+ )
212
+ return hidden_states
213
+
214
+
215
+ class MambaLMHeadModel(nn.Module, GenerationMixin):
216
+
217
+ def __init__(
218
+ self,
219
+ config: MambaConfig,
220
+ initializer_cfg=None,
221
+ device=None,
222
+ dtype=None,
223
+ ) -> None:
224
+ self.config = config
225
+ d_model = config.d_model
226
+ n_layer = config.n_layer
227
+ d_intermediate = config.d_intermediate
228
+ vocab_size = config.vocab_size
229
+ ssm_cfg = config.ssm_cfg
230
+ attn_layer_idx = config.attn_layer_idx
231
+ attn_cfg = config.attn_cfg
232
+ rms_norm = config.rms_norm
233
+ residual_in_fp32 = config.residual_in_fp32
234
+ fused_add_norm = config.fused_add_norm
235
+ pad_vocab_size_multiple = config.pad_vocab_size_multiple
236
+ factory_kwargs = {"device": device, "dtype": dtype}
237
+
238
+ super().__init__()
239
+ if vocab_size % pad_vocab_size_multiple != 0:
240
+ vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
241
+ self.backbone = MixerModel(
242
+ d_model=d_model,
243
+ n_layer=n_layer,
244
+ d_intermediate=d_intermediate,
245
+ vocab_size=vocab_size,
246
+ ssm_cfg=ssm_cfg,
247
+ attn_layer_idx=attn_layer_idx,
248
+ attn_cfg=attn_cfg,
249
+ rms_norm=rms_norm,
250
+ initializer_cfg=initializer_cfg,
251
+ fused_add_norm=fused_add_norm,
252
+ residual_in_fp32=residual_in_fp32,
253
+ **factory_kwargs,
254
+ )
255
+ self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
256
+
257
+ # Initialize weights and apply final processing
258
+ self.apply(
259
+ partial(
260
+ _init_weights,
261
+ n_layer=n_layer,
262
+ **(initializer_cfg if initializer_cfg is not None else {}),
263
+ )
264
+ )
265
+ self.tie_weights()
266
+
267
+ def tie_weights(self):
268
+ if self.config.tie_embeddings:
269
+ self.lm_head.weight = self.backbone.embedding.weight
270
+
271
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
272
+ return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
273
+
274
+ def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, **mixer_kwargs):
275
+ """
276
+ "position_ids" is just to be compatible with Transformer generation. We don't use it.
277
+ num_last_tokens: if > 0, only return the logits for the last n tokens
278
+ """
279
+ hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs)
280
+ if num_last_tokens > 0:
281
+ hidden_states = hidden_states[:, -num_last_tokens:]
282
+ lm_logits = self.lm_head(hidden_states)
283
+ CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
284
+ return CausalLMOutput(logits=lm_logits)
285
+
286
+ @classmethod
287
+ def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
288
+ config_data = load_config_hf(pretrained_model_name)
289
+ config = MambaConfig(**config_data)
290
+ model = cls(config, device=device, dtype=dtype, **kwargs)
291
+ model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
292
+ return model
293
+
294
+ def save_pretrained(self, save_directory):
295
+ """
296
+ Minimal implementation of save_pretrained for MambaLMHeadModel.
297
+ Save the model and its configuration file to a directory.
298
+ """
299
+ # Ensure save_directory exists
300
+ os.makedirs(save_directory, exist_ok=True)
301
+
302
+ # Save the model's state_dict
303
+ model_path = os.path.join(save_directory, 'pytorch_model.bin')
304
+ torch.save(self.state_dict(), model_path)
305
+
306
+ # Save the configuration of the model
307
+ config_path = os.path.join(save_directory, 'config.json')
308
+ with open(config_path, 'w') as f:
309
+ json.dump(self.config.__dict__, f, indent=4)
build/torch210-cxx11-cu126-x86_64-linux/modules/__init__.py ADDED
File without changes
build/torch210-cxx11-cu126-x86_64-linux/modules/block.py ADDED
@@ -0,0 +1,107 @@
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
+ from typing import Optional
+
+ import torch
+ from torch import nn, Tensor
+
+ from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn
+
+
+ class Block(nn.Module):
+     def __init__(
+         self,
+         dim,
+         mixer_cls,
+         mlp_cls,
+         norm_cls=nn.LayerNorm,
+         fused_add_norm=False,
+         residual_in_fp32=False,
+     ):
+         """
+         Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection"
+
+         This Block has a slightly different structure compared to a regular
+         prenorm Transformer block.
+         The standard block is: LN -> MHA/MLP -> Add.
+         [Ref: https://arxiv.org/abs/2002.04745]
+         Here we have: Add -> LN -> Mixer, returning both
+         the hidden_states (output of the mixer) and the residual.
+         This is purely for performance reasons, as we can fuse add and LayerNorm.
+         The residual needs to be provided (except for the very first block).
+         """
+         super().__init__()
+         self.residual_in_fp32 = residual_in_fp32
+         self.fused_add_norm = fused_add_norm
+         self.norm = norm_cls(dim)
+         self.mixer = mixer_cls(dim)
+         if mlp_cls is not nn.Identity:
+             self.norm2 = norm_cls(dim)
+             self.mlp = mlp_cls(dim)
+         else:
+             self.mlp = None
+         if self.fused_add_norm:
+             assert RMSNorm is not None, "RMSNorm import fails"
+             assert isinstance(
+                 self.norm, (nn.LayerNorm, RMSNorm)
+             ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
+
+     def forward(
+         self,
+         hidden_states: Tensor,
+         residual: Optional[Tensor] = None,
+         inference_params=None,
+         **mixer_kwargs
+     ):
+         r"""Pass the input through the encoder layer.
+
+         Args:
+             hidden_states: the sequence to the encoder layer (required).
+             residual: hidden_states = Mixer(LN(residual))
+         """
+         if not self.fused_add_norm:
+             residual = (
+                 (hidden_states + residual) if residual is not None else hidden_states
+             )
+             hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
+             if self.residual_in_fp32:
+                 residual = residual.to(torch.float32)
+         else:
+             hidden_states, residual = layer_norm_fn(
+                 hidden_states,
+                 self.norm.weight,
+                 self.norm.bias,
+                 residual=residual,
+                 prenorm=True,
+                 residual_in_fp32=self.residual_in_fp32,
+                 eps=self.norm.eps,
+                 is_rms_norm=isinstance(self.norm, RMSNorm),
+             )
+         hidden_states = self.mixer(
+             hidden_states, inference_params=inference_params, **mixer_kwargs
+         )
+
+         if self.mlp is not None:
+             if not self.fused_add_norm:
+                 residual = hidden_states + residual
+                 hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
+                 if self.residual_in_fp32:
+                     residual = residual.to(torch.float32)
+             else:
+                 hidden_states, residual = layer_norm_fn(
+                     hidden_states,
+                     self.norm2.weight,
+                     self.norm2.bias,
+                     residual=residual,
+                     prenorm=True,
+                     residual_in_fp32=self.residual_in_fp32,
+                     eps=self.norm2.eps,
+                     is_rms_norm=isinstance(self.norm2, RMSNorm),
+                 )
+             hidden_states = self.mlp(hidden_states)
+
+         return hidden_states, residual
+
+     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+         return self.mixer.allocate_inference_cache(
+             batch_size, max_seqlen, dtype=dtype, **kwargs
+         )
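A minimal sketch (not part of the commit) of how the (hidden_states, residual) pair threads through stacked blocks on the non-fused path; the toy mixer below is hypothetical and only satisfies the interface Block expects.

    import torch
    import torch.nn as nn
    from modules.block import Block  # assuming the built directory is importable as a package

    class ToyMixer(nn.Module):
        # Block calls mixer_cls(dim), then mixer(hidden_states, inference_params=..., **kwargs)
        def __init__(self, dim):
            super().__init__()
        def forward(self, hidden_states, inference_params=None, **kwargs):
            return hidden_states

    blocks = nn.ModuleList(Block(16, ToyMixer, nn.Identity) for _ in range(2))
    hidden_states, residual = torch.randn(2, 5, 16), None
    for blk in blocks:
        # Add -> LN -> Mixer: each block returns new hidden_states plus the running residual,
        # which MixerModel.norm_f finally adds back in and normalizes.
        hidden_states, residual = blk(hidden_states, residual)
    out = hidden_states + residual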
build/torch210-cxx11-cu126-x86_64-linux/modules/mamba2.py ADDED
@@ -0,0 +1,502 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from einops import rearrange, repeat
10
+
11
+ try:
12
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
13
+ except ImportError:
14
+ causal_conv1d_fn, causal_conv1d_update = None, None
15
+
16
+ try:
17
+ from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states
18
+ except ImportError:
19
+ causal_conv1d_varlen_states = None
20
+
21
+ try:
22
+ from ..ops.triton.selective_state_update import selective_state_update
23
+ except ImportError:
24
+ selective_state_update = None
25
+
26
+ from ..ops.triton.layernorm_gated import RMSNorm as RMSNormGated
27
+
28
+ from ..distributed.tensor_parallel import ColumnParallelLinear, RowParallelLinear
29
+ from ..distributed.distributed_utils import all_reduce, reduce_scatter
30
+
31
+ from ..ops.triton.ssd_combined import mamba_chunk_scan_combined
32
+ from ..ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
33
+
34
+ from huggingface_hub import PyTorchModelHubMixin
35
+
36
+
37
+ class Mamba2(nn.Module, PyTorchModelHubMixin):
38
+ def __init__(
39
+ self,
40
+ d_model,
41
+ d_state=128,
42
+ d_conv=4,
43
+ conv_init=None,
44
+ expand=2,
45
+ headdim=64,
46
+ d_ssm=None, # If not None, we only apply SSM on this many dimensions, the rest uses gated MLP
47
+ ngroups=1,
48
+ A_init_range=(1, 16),
49
+ D_has_hdim=False,
50
+ rmsnorm=True,
51
+ norm_before_gate=False,
52
+ dt_min=0.001,
53
+ dt_max=0.1,
54
+ dt_init_floor=1e-4,
55
+ dt_limit=(0.0, float("inf")),
56
+ bias=False,
57
+ conv_bias=True,
58
+ # Fused kernel and sharding options
59
+ chunk_size=256,
60
+ use_mem_eff_path=True,
61
+ layer_idx=None, # Absorb kwarg for general module
62
+ process_group=None,
63
+ sequence_parallel=True,
64
+ device=None,
65
+ dtype=None,
66
+ ):
67
+ factory_kwargs = {"device": device, "dtype": dtype}
68
+ super().__init__()
69
+ self.d_model = d_model
70
+ self.d_state = d_state
71
+ self.d_conv = d_conv
72
+ self.conv_init = conv_init
73
+ self.expand = expand
74
+ self.process_group = process_group
75
+ self.sequence_parallel = sequence_parallel
76
+ self.world_size = 1 if process_group is None else process_group.size()
77
+ self.local_rank = 0 if process_group is None else process_group.rank()
78
+ self.d_inner = (self.expand * self.d_model) // self.world_size
79
+ assert self.d_inner * self.world_size == self.expand * self.d_model
80
+ self.headdim = headdim
81
+ self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size
82
+ assert ngroups % self.world_size == 0
83
+ self.ngroups = ngroups // self.world_size
84
+ assert self.d_ssm % self.headdim == 0
85
+ self.nheads = self.d_ssm // self.headdim
86
+ self.D_has_hdim = D_has_hdim
87
+ self.rmsnorm = rmsnorm
88
+ self.norm_before_gate = norm_before_gate
89
+ self.dt_limit = dt_limit
90
+ self.activation = "silu"
91
+ self.chunk_size = chunk_size
92
+ self.use_mem_eff_path = use_mem_eff_path
93
+ self.layer_idx = layer_idx
94
+
95
+ # Order: [z, x, B, C, dt]
96
+ d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
97
+ if self.process_group is None:
98
+ self.in_proj = nn.Linear(
99
+ self.d_model, d_in_proj, bias=bias, **factory_kwargs
100
+ )
101
+ else:
102
+ self.in_proj = ColumnParallelLinear(
103
+ self.d_model,
104
+ d_in_proj * self.world_size,
105
+ bias=bias,
106
+ process_group=self.process_group,
107
+ sequence_parallel=self.sequence_parallel,
108
+ **factory_kwargs,
109
+ )
110
+
111
+ conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state
112
+ self.conv1d = nn.Conv1d(
113
+ in_channels=conv_dim,
114
+ out_channels=conv_dim,
115
+ bias=conv_bias,
116
+ kernel_size=d_conv,
117
+ groups=conv_dim,
118
+ padding=d_conv - 1,
119
+ **factory_kwargs,
120
+ )
121
+ if self.conv_init is not None:
122
+ nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
123
+
124
+ self.act = nn.SiLU()
125
+
126
+ # Initialize log dt bias
127
+ dt = torch.exp(
128
+ torch.rand(self.nheads, **factory_kwargs)
129
+ * (math.log(dt_max) - math.log(dt_min))
130
+ + math.log(dt_min)
131
+ )
132
+ dt = torch.clamp(dt, min=dt_init_floor)
133
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
134
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
135
+ self.dt_bias = nn.Parameter(inv_dt)
136
+ # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
137
+ # name.endswith("bias") in param_grouping.py
138
+ self.dt_bias._no_weight_decay = True
139
+
140
+ assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
141
+ A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(
142
+ *A_init_range
143
+ )
144
+ A_log = torch.log(A).to(dtype=dtype)
145
+ self.A_log = nn.Parameter(A_log)
146
+ self.A_log._no_weight_decay = True
147
+
148
+ # D "skip" parameter
149
+ self.D = nn.Parameter(
150
+ torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device)
151
+ )
152
+ self.D._no_weight_decay = True
153
+
154
+ if self.rmsnorm:
155
+ assert RMSNormGated is not None
156
+ self.norm = RMSNormGated(
157
+ self.d_ssm,
158
+ eps=1e-5,
159
+ norm_before_gate=self.norm_before_gate,
160
+ group_size=self.d_ssm // ngroups,
161
+ **factory_kwargs,
162
+ )
163
+
164
+ if self.process_group is None:
165
+ self.out_proj = nn.Linear(
166
+ self.d_inner, self.d_model, bias=bias, **factory_kwargs
167
+ )
168
+ else:
169
+ self.out_proj = RowParallelLinear(
170
+ self.d_inner * self.world_size,
171
+ self.d_model,
172
+ bias=bias,
173
+ process_group=self.process_group,
174
+ sequence_parallel=self.sequence_parallel,
175
+ **factory_kwargs,
176
+ )
177
+
178
+ def forward(
179
+ self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None
180
+ ):
181
+ """
182
+ u: (batch, seqlen, hidden_dim) if seqlen=None.
183
+ If seqlen is not None, u is (batch * seqlen, hidden_dim). This is so that when we
184
+ split u during sequence parallel, we split the batch * seqlen dimension
185
+ (in case batch is small).
186
+ Returns: same shape as u
187
+ """
188
+ seqlen_og = seqlen
189
+ if seqlen is None:
190
+ batch, seqlen, dim = u.shape
191
+ else:
192
+ batch_seqlen, dim = u.shape
193
+ batch = batch_seqlen // seqlen
194
+
195
+ conv_state, ssm_state = None, None
196
+ if inference_params is not None:
197
+ inference_batch = (
198
+ cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch
199
+ )
200
+ conv_state, ssm_state = self._get_states_from_cache(
201
+ inference_params, inference_batch
202
+ )
203
+ if inference_params.seqlen_offset > 0:
204
+ # The states are updated inplace
205
+ out, _, _ = self.step(u, conv_state, ssm_state)
206
+ return out
207
+
208
+ zxbcdt = self.in_proj(u) # (B, L, d_in_proj) or (B * L, d_in_proj)
209
+ if seqlen_og is not None:
210
+ zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen)
211
+ # If the model is loaded in fp16, without the .float() here, A might be -inf
212
+ A = -torch.exp(self.A_log.float()) # (nheads) or (d_inner, d_state)
213
+ dt_limit_kwargs = (
214
+ {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
215
+ )
216
+ if self.use_mem_eff_path and inference_params is None:
217
+ out = mamba_split_conv1d_scan_combined(
218
+ zxbcdt,
219
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
220
+ self.conv1d.bias,
221
+ self.dt_bias,
222
+ A,
223
+ D=(
224
+ rearrange(self.D, "(h p) -> h p", p=self.headdim)
225
+ if self.D_has_hdim
226
+ else self.D
227
+ ),
228
+ chunk_size=self.chunk_size,
229
+ seq_idx=seq_idx,
230
+ activation=self.activation,
231
+ rmsnorm_weight=self.norm.weight if self.rmsnorm else None,
232
+ rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6,
233
+ outproj_weight=self.out_proj.weight,
234
+ outproj_bias=self.out_proj.bias,
235
+ headdim=None if self.D_has_hdim else self.headdim,
236
+ ngroups=self.ngroups,
237
+ norm_before_gate=self.norm_before_gate,
238
+ **dt_limit_kwargs,
239
+ )
240
+ if seqlen_og is not None:
241
+ out = rearrange(out, "b l d -> (b l) d")
242
+ if self.process_group is not None:
243
+ reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
244
+ out = reduce_fn(out, self.process_group)
245
+ else:
246
+ d_mlp = (
247
+ zxbcdt.shape[-1]
248
+ - 2 * self.d_ssm
249
+ - 2 * self.ngroups * self.d_state
250
+ - self.nheads
251
+ ) // 2
252
+ z0, x0, z, xBC, dt = torch.split(
253
+ zxbcdt,
254
+ [
255
+ d_mlp,
256
+ d_mlp,
257
+ self.d_ssm,
258
+ self.d_ssm + 2 * self.ngroups * self.d_state,
259
+ self.nheads,
260
+ ],
261
+ dim=-1,
262
+ )
263
+ if conv_state is not None:
264
+ if cu_seqlens is None:
265
+ # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
266
+ # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
267
+ xBC_t = rearrange(xBC, "b l d -> b d l")
268
+ conv_state.copy_(
269
+ F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))
270
+ ) # Update state (B D W)
271
+ else:
272
+ assert (
273
+ causal_conv1d_varlen_states is not None
274
+ ), "varlen inference requires causal_conv1d package"
275
+ assert (
276
+ batch == 1
277
+ ), "varlen inference only supports batch dimension 1"
278
+ conv_varlen_states = causal_conv1d_varlen_states(
279
+ xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1]
280
+ )
281
+ conv_state.copy_(conv_varlen_states)
282
+ assert self.activation in ["silu", "swish"]
283
+ if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
284
+ assert (
285
+ seq_idx is None
286
+ ), "varlen conv1d requires the causal_conv1d package"
287
+ xBC = self.act(
288
+ self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[
289
+ :, : -(self.d_conv - 1)
290
+ ]
291
+ ) # (B, L, self.d_ssm + 2 * ngroups * d_state)
292
+ else:
293
+ xBC = causal_conv1d_fn(
294
+ xBC.transpose(1, 2),
295
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
296
+ bias=self.conv1d.bias,
297
+ activation=self.activation,
298
+ seq_idx=seq_idx,
299
+ ).transpose(1, 2)
300
+ x, B, C = torch.split(
301
+ xBC,
302
+ [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state],
303
+ dim=-1,
304
+ )
305
+ y = mamba_chunk_scan_combined(
306
+ rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
307
+ dt,
308
+ A,
309
+ rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
310
+ rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
311
+ chunk_size=self.chunk_size,
312
+ D=(
313
+ rearrange(self.D, "(h p) -> h p", p=self.headdim)
314
+ if self.D_has_hdim
315
+ else self.D
316
+ ),
317
+ z=(
318
+ rearrange(z, "b l (h p) -> b l h p", p=self.headdim)
319
+ if not self.rmsnorm
320
+ else None
321
+ ),
322
+ dt_bias=self.dt_bias,
323
+ dt_softplus=True,
324
+ seq_idx=seq_idx,
325
+ cu_seqlens=cu_seqlens,
326
+ **dt_limit_kwargs,
327
+ return_final_states=ssm_state is not None,
328
+ return_varlen_states=cu_seqlens is not None
329
+ and inference_params is not None,
330
+ )
331
+ if ssm_state is not None:
332
+ y, last_state, *rest = y
333
+ if cu_seqlens is None:
334
+ ssm_state.copy_(last_state)
335
+ else:
336
+ varlen_states = rest[0]
337
+ ssm_state.copy_(varlen_states)
338
+ y = rearrange(y, "b l h p -> b l (h p)")
339
+ if self.rmsnorm:
340
+ y = self.norm(y, z)
341
+ if d_mlp > 0:
342
+ y = torch.cat([F.silu(z0) * x0, y], dim=-1)
343
+ if seqlen_og is not None:
344
+ y = rearrange(y, "b l d -> (b l) d")
345
+ out = self.out_proj(y)
346
+ return out
347
+
348
+ def step(self, hidden_states, conv_state, ssm_state):
349
+ dtype = hidden_states.dtype
350
+ assert (
351
+ hidden_states.shape[1] == 1
352
+ ), "Only support decoding with 1 token at a time for now"
353
+ zxbcdt = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
354
+ d_mlp = (
355
+ zxbcdt.shape[-1]
356
+ - 2 * self.d_ssm
357
+ - 2 * self.ngroups * self.d_state
358
+ - self.nheads
359
+ ) // 2
360
+ z0, x0, z, xBC, dt = torch.split(
361
+ zxbcdt,
362
+ [
363
+ d_mlp,
364
+ d_mlp,
365
+ self.d_ssm,
366
+ self.d_ssm + 2 * self.ngroups * self.d_state,
367
+ self.nheads,
368
+ ],
369
+ dim=-1,
370
+ )
371
+
372
+ # Conv step
373
+ if causal_conv1d_update is None:
374
+ conv_state.copy_(
375
+ torch.roll(conv_state, shifts=-1, dims=-1)
376
+ ) # Update state (B D W)
377
+ conv_state[:, :, -1] = xBC
378
+ xBC = torch.sum(
379
+ conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
380
+ ) # (B D)
381
+ if self.conv1d.bias is not None:
382
+ xBC = xBC + self.conv1d.bias
383
+ xBC = self.act(xBC).to(dtype=dtype)
384
+ else:
385
+ xBC = causal_conv1d_update(
386
+ xBC,
387
+ conv_state,
388
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
389
+ self.conv1d.bias,
390
+ self.activation,
391
+ )
392
+
393
+ x, B, C = torch.split(
394
+ xBC,
395
+ [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state],
396
+ dim=-1,
397
+ )
398
+ A = -torch.exp(self.A_log.float()) # (nheads,)
399
+
400
+ # SSM step
401
+ if selective_state_update is None:
402
+ assert (
403
+ self.ngroups == 1
404
+ ), "Only support ngroups=1 for this inference code path"
405
+ # Discretize A and B
406
+ dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype)) # (batch, nheads)
407
+ dA = torch.exp(dt * A) # (batch, nheads)
408
+ x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
409
+ dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
410
+ ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
411
+ y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C)
412
+ y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
413
+ y = rearrange(y, "b h p -> b (h p)")
414
+ if not self.rmsnorm:
415
+ y = y * self.act(z) # (B D)
416
+ else:
417
+ A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(
418
+ dtype=torch.float32
419
+ )
420
+ dt = repeat(dt, "b h -> b h p", p=self.headdim)
421
+ dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim)
422
+ D = repeat(self.D, "h -> h p", p=self.headdim)
423
+ B = rearrange(B, "b (g n) -> b g n", g=self.ngroups)
424
+ C = rearrange(C, "b (g n) -> b g n", g=self.ngroups)
425
+ x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim)
426
+ if not self.rmsnorm:
427
+ z = rearrange(z, "b (h p) -> b h p", p=self.headdim)
428
+ y = selective_state_update(
429
+ ssm_state,
430
+ x_reshaped,
431
+ dt,
432
+ A,
433
+ B,
434
+ C,
435
+ D,
436
+ z=z if not self.rmsnorm else None,
437
+ dt_bias=dt_bias,
438
+ dt_softplus=True,
439
+ )
440
+ y = rearrange(y, "b h p -> b (h p)")
441
+ if self.rmsnorm:
442
+ y = self.norm(y, z)
443
+ if d_mlp > 0:
444
+ y = torch.cat([F.silu(z0) * x0, y], dim=-1)
445
+ out = self.out_proj(y)
446
+ return out.unsqueeze(1), conv_state, ssm_state
447
+
448
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
449
+ device = self.out_proj.weight.device
450
+ conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
451
+ conv_state = torch.zeros(
452
+ batch_size,
453
+ self.d_conv,
454
+ self.conv1d.weight.shape[0],
455
+ device=device,
456
+ dtype=conv_dtype,
457
+ ).transpose(1, 2)
458
+ ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype
459
+ ssm_state = torch.zeros(
460
+ batch_size,
461
+ self.nheads,
462
+ self.headdim,
463
+ self.d_state,
464
+ device=device,
465
+ dtype=ssm_dtype,
466
+ )
467
+ return conv_state, ssm_state
468
+
469
+ def _get_states_from_cache(
470
+ self, inference_params, batch_size, initialize_states=False
471
+ ):
472
+ assert self.layer_idx is not None
473
+ if self.layer_idx not in inference_params.key_value_memory_dict:
474
+ batch_shape = (batch_size,)
475
+ conv_state = torch.zeros(
476
+ batch_size,
477
+ self.d_conv,
478
+ self.conv1d.weight.shape[0],
479
+ device=self.conv1d.weight.device,
480
+ dtype=self.conv1d.weight.dtype,
481
+ ).transpose(1, 2)
482
+ ssm_state = torch.zeros(
483
+ batch_size,
484
+ self.nheads,
485
+ self.headdim,
486
+ self.d_state,
487
+ device=self.in_proj.weight.device,
488
+ dtype=self.in_proj.weight.dtype,
489
+ )
490
+ inference_params.key_value_memory_dict[self.layer_idx] = (
491
+ conv_state,
492
+ ssm_state,
493
+ )
494
+ else:
495
+ conv_state, ssm_state = inference_params.key_value_memory_dict[
496
+ self.layer_idx
497
+ ]
498
+ # TODO: What if batch size changes between generation, and we reuse the same states?
499
+ if initialize_states:
500
+ conv_state.zero_()
501
+ ssm_state.zero_()
502
+ return conv_state, ssm_state
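For orientation, here is a minimal single-token decode sketch against the cache helpers defined above. It is illustrative only: the layer's constructor sits outside this excerpt, the sizes, device, and dtype below are assumptions, and the fused step path needs the compiled kernels from this build.

import torch

def mamba2_decode_step(block, token_emb, conv_state, ssm_state):
    # block: an instance of the layer defined above; conv_state/ssm_state come from
    # block.allocate_inference_cache(batch_size, max_seqlen).
    # token_emb: (batch, 1, d_model) -- step() asserts exactly one token at a time.
    out, conv_state, ssm_state = block.step(token_emb, conv_state, ssm_state)
    return out, conv_state, ssm_state

# Hypothetical usage (assumed sizes/device, not part of this diff):
# conv_state, ssm_state = block.allocate_inference_cache(batch_size=2, max_seqlen=128)
# tok = torch.randn(2, 1, 256, device="cuda", dtype=torch.bfloat16)
# out, conv_state, ssm_state = mamba2_decode_step(block, tok, conv_state, ssm_state)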
build/torch210-cxx11-cu126-x86_64-linux/modules/mamba2_simple.py ADDED
@@ -0,0 +1,229 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ import math
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from einops import rearrange, repeat
9
+
10
+ try:
11
+ from causal_conv1d import causal_conv1d_fn
12
+ except ImportError:
13
+ causal_conv1d_fn = None
14
+
15
+ try:
16
+ from ..ops.triton.layernorm_gated import RMSNorm as RMSNormGated, LayerNorm
17
+ except ImportError:
18
+ RMSNormGated, LayerNorm = None, None
19
+
20
+ from ..ops.triton.ssd_combined import mamba_chunk_scan_combined
21
+ from ..ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
22
+
23
+
24
+ class Mamba2Simple(nn.Module):
25
+ def __init__(
26
+ self,
27
+ d_model,
28
+ d_state=64,
29
+ d_conv=4,
30
+ conv_init=None,
31
+ expand=2,
32
+ headdim=128,
33
+ ngroups=1,
34
+ A_init_range=(1, 16),
35
+ dt_min=0.001,
36
+ dt_max=0.1,
37
+ dt_init_floor=1e-4,
38
+ dt_limit=(0.0, float("inf")),
39
+ learnable_init_states=False,
40
+ activation="swish",
41
+ bias=False,
42
+ conv_bias=True,
43
+ # Fused kernel and sharding options
44
+ chunk_size=256,
45
+ use_mem_eff_path=True,
46
+ layer_idx=None, # Absorb kwarg for general module
47
+ device=None,
48
+ dtype=None,
49
+ ):
50
+ factory_kwargs = {"device": device, "dtype": dtype}
51
+ super().__init__()
52
+ self.d_model = d_model
53
+ self.d_state = d_state
54
+ self.d_conv = d_conv
55
+ self.conv_init = conv_init
56
+ self.expand = expand
57
+ self.d_inner = self.expand * self.d_model
58
+ self.headdim = headdim
59
+ self.ngroups = ngroups
60
+ assert self.d_inner % self.headdim == 0
61
+ self.nheads = self.d_inner // self.headdim
62
+ self.dt_limit = dt_limit
63
+ self.learnable_init_states = learnable_init_states
64
+ self.activation = activation
65
+ self.chunk_size = chunk_size
66
+ self.use_mem_eff_path = use_mem_eff_path
67
+ self.layer_idx = layer_idx
68
+
69
+ # Order: [z, x, B, C, dt]
70
+ d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
71
+ self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)
72
+
73
+ conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
74
+ self.conv1d = nn.Conv1d(
75
+ in_channels=conv_dim,
76
+ out_channels=conv_dim,
77
+ bias=conv_bias,
78
+ kernel_size=d_conv,
79
+ groups=conv_dim,
80
+ padding=d_conv - 1,
81
+ **factory_kwargs,
82
+ )
83
+ if self.conv_init is not None:
84
+ nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
85
+ # self.conv1d.weight._no_weight_decay = True
86
+
87
+ if self.learnable_init_states:
88
+ self.init_states = nn.Parameter(
89
+ torch.zeros(self.nheads, self.headdim, self.d_state, **factory_kwargs)
90
+ )
91
+ self.init_states._no_weight_decay = True
92
+
93
+ self.act = nn.SiLU()
94
+
95
+ # Initialize log dt bias
96
+ dt = torch.exp(
97
+ torch.rand(self.nheads, **factory_kwargs)
98
+ * (math.log(dt_max) - math.log(dt_min))
99
+ + math.log(dt_min)
100
+ )
101
+ dt = torch.clamp(dt, min=dt_init_floor)
102
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
103
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
104
+ self.dt_bias = nn.Parameter(inv_dt)
105
+ # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
106
+ # name.endswith("bias") in param_grouping.py
107
+ self.dt_bias._no_weight_decay = True
108
+
109
+ # A parameter
110
+ assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
111
+ A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(
112
+ *A_init_range
113
+ )
114
+ A_log = torch.log(A).to(dtype=dtype)
115
+ self.A_log = nn.Parameter(A_log)
116
+ # self.register_buffer("A_log", torch.zeros(self.nheads, dtype=torch.float32, device=device), persistent=True)
117
+ self.A_log._no_weight_decay = True
118
+
119
+ # D "skip" parameter
120
+ self.D = nn.Parameter(torch.ones(self.nheads, device=device))
121
+ self.D._no_weight_decay = True
122
+
123
+ # Extra normalization layer right before output projection
124
+ assert RMSNormGated is not None
125
+ self.norm = RMSNormGated(
126
+ self.d_inner, eps=1e-5, norm_before_gate=False, **factory_kwargs
127
+ )
128
+
129
+ self.out_proj = nn.Linear(
130
+ self.d_inner, self.d_model, bias=bias, **factory_kwargs
131
+ )
132
+
133
+ def forward(self, u, seq_idx=None):
134
+ """
135
+ u: (B, L, D)
136
+ Returns: same shape as u
137
+ """
138
+ batch, seqlen, dim = u.shape
139
+
140
+ zxbcdt = self.in_proj(u) # (B, L, d_in_proj)
141
+ A = -torch.exp(self.A_log) # (nheads) or (d_inner, d_state)
142
+ initial_states = (
143
+ repeat(self.init_states, "... -> b ...", b=batch)
144
+ if self.learnable_init_states
145
+ else None
146
+ )
147
+ dt_limit_kwargs = (
148
+ {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
149
+ )
150
+
151
+ if self.use_mem_eff_path:
152
+ # Fully fused path
153
+ out = mamba_split_conv1d_scan_combined(
154
+ zxbcdt,
155
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
156
+ self.conv1d.bias,
157
+ self.dt_bias,
158
+ A,
159
+ D=self.D,
160
+ chunk_size=self.chunk_size,
161
+ seq_idx=seq_idx,
162
+ activation=self.activation,
163
+ rmsnorm_weight=self.norm.weight,
164
+ rmsnorm_eps=self.norm.eps,
165
+ outproj_weight=self.out_proj.weight,
166
+ outproj_bias=self.out_proj.bias,
167
+ headdim=self.headdim,
168
+ ngroups=self.ngroups,
169
+ norm_before_gate=False,
170
+ initial_states=initial_states,
171
+ **dt_limit_kwargs,
172
+ )
173
+ else:
174
+ z, xBC, dt = torch.split(
175
+ zxbcdt,
176
+ [
177
+ self.d_inner,
178
+ self.d_inner + 2 * self.ngroups * self.d_state,
179
+ self.nheads,
180
+ ],
181
+ dim=-1,
182
+ )
183
+ dt = F.softplus(dt + self.dt_bias) # (B, L, nheads)
184
+ assert self.activation in ["silu", "swish"]
185
+
186
+ # 1D Convolution
187
+ if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
188
+ xBC = self.act(
189
+ self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)
190
+ ) # (B, L, self.d_inner + 2 * ngroups * d_state)
191
+ xBC = xBC[:, :seqlen, :]
192
+ else:
193
+ xBC = causal_conv1d_fn(
194
+ x=xBC.transpose(1, 2),
195
+ weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
196
+ bias=self.conv1d.bias,
197
+ activation=self.activation,
198
+ ).transpose(1, 2)
199
+
200
+ # Split into 3 main branches: X, B, C
201
+ # These correspond to V, K, Q respectively in the SSM/attention duality
202
+ x, B, C = torch.split(
203
+ xBC,
204
+ [
205
+ self.d_inner,
206
+ self.ngroups * self.d_state,
207
+ self.ngroups * self.d_state,
208
+ ],
209
+ dim=-1,
210
+ )
211
+ y = mamba_chunk_scan_combined(
212
+ rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
213
+ dt,
214
+ A,
215
+ rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
216
+ rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
217
+ chunk_size=self.chunk_size,
218
+ D=self.D,
219
+ z=None,
220
+ seq_idx=seq_idx,
221
+ initial_states=initial_states,
222
+ **dt_limit_kwargs,
223
+ )
224
+ y = rearrange(y, "b l h p -> b l (h p)")
225
+
226
+ # Multiply "gate" branch and apply extra normalization layer
227
+ y = self.norm(y, z)
228
+ out = self.out_proj(y)
229
+ return out
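A short, hedged usage sketch for Mamba2Simple: the sizes, device, and dtype are illustrative assumptions, and both the fully fused and chunked paths require the Triton kernels from this build on a CUDA device.

# Hypothetical usage, not part of this diff:
# import torch
# block = Mamba2Simple(d_model=256, d_state=64, headdim=64, device="cuda", dtype=torch.bfloat16)
# u = torch.randn(2, 512, 256, device="cuda", dtype=torch.bfloat16)
# y = block(u)  # same shape as u: (2, 512, 256); d_inner = expand * d_model must divide by headdim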
build/torch210-cxx11-cu126-x86_64-linux/modules/mamba_simple.py ADDED
@@ -0,0 +1,339 @@
1
+ # Copyright (c) 2023, Tri Dao, Albert Gu.
2
+
3
+ import math
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch import Tensor
10
+
11
+ from einops import rearrange, repeat
12
+
13
+ from ..ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
14
+
15
+ try:
16
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
17
+ except ImportError:
18
+ causal_conv1d_fn, causal_conv1d_update = None, None
19
+
20
+ try:
21
+ from ..ops.triton.selective_state_update import selective_state_update
22
+ except ImportError:
23
+ selective_state_update = None
24
+
25
+ try:
26
+ from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
27
+ except ImportError:
28
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
29
+
30
+
31
+ class Mamba(nn.Module):
32
+ def __init__(
33
+ self,
34
+ d_model,
35
+ d_state=16,
36
+ d_conv=4,
37
+ expand=2,
38
+ dt_rank="auto",
39
+ dt_min=0.001,
40
+ dt_max=0.1,
41
+ dt_init="random",
42
+ dt_scale=1.0,
43
+ dt_init_floor=1e-4,
44
+ conv_bias=True,
45
+ bias=False,
46
+ use_fast_path=True, # Fused kernel options
47
+ layer_idx=None,
48
+ device=None,
49
+ dtype=None,
50
+ ):
51
+ factory_kwargs = {"device": device, "dtype": dtype}
52
+ super().__init__()
53
+ self.d_model = d_model
54
+ self.d_state = d_state
55
+ self.d_conv = d_conv
56
+ self.expand = expand
57
+ self.d_inner = int(self.expand * self.d_model)
58
+ self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
59
+ self.use_fast_path = use_fast_path
60
+ self.layer_idx = layer_idx
61
+
62
+ self.in_proj = nn.Linear(
63
+ self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs
64
+ )
65
+
66
+ self.conv1d = nn.Conv1d(
67
+ in_channels=self.d_inner,
68
+ out_channels=self.d_inner,
69
+ bias=conv_bias,
70
+ kernel_size=d_conv,
71
+ groups=self.d_inner,
72
+ padding=d_conv - 1,
73
+ **factory_kwargs,
74
+ )
75
+
76
+ self.activation = "silu"
77
+ self.act = nn.SiLU()
78
+
79
+ self.x_proj = nn.Linear(
80
+ self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
81
+ )
82
+ self.dt_proj = nn.Linear(
83
+ self.dt_rank, self.d_inner, bias=True, **factory_kwargs
84
+ )
85
+
86
+ # Initialize special dt projection to preserve variance at initialization
87
+ dt_init_std = self.dt_rank**-0.5 * dt_scale
88
+ if dt_init == "constant":
89
+ nn.init.constant_(self.dt_proj.weight, dt_init_std)
90
+ elif dt_init == "random":
91
+ nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
92
+ else:
93
+ raise NotImplementedError
94
+
95
+ # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
96
+ dt = torch.exp(
97
+ torch.rand(self.d_inner, **factory_kwargs)
98
+ * (math.log(dt_max) - math.log(dt_min))
99
+ + math.log(dt_min)
100
+ ).clamp(min=dt_init_floor)
101
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
102
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
103
+ with torch.no_grad():
104
+ self.dt_proj.bias.copy_(inv_dt)
105
+ # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
106
+ self.dt_proj.bias._no_reinit = True
107
+
108
+ # S4D real initialization
109
+ A = repeat(
110
+ torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
111
+ "n -> d n",
112
+ d=self.d_inner,
113
+ ).contiguous()
114
+ A_log = torch.log(A) # Keep A_log in fp32
115
+ self.A_log = nn.Parameter(A_log)
116
+ self.A_log._no_weight_decay = True
117
+
118
+ # D "skip" parameter
119
+ self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
120
+ self.D._no_weight_decay = True
121
+
122
+ self.out_proj = nn.Linear(
123
+ self.d_inner, self.d_model, bias=bias, **factory_kwargs
124
+ )
125
+
126
+ def forward(self, hidden_states, inference_params=None):
127
+ """
128
+ hidden_states: (B, L, D)
129
+ Returns: same shape as hidden_states
130
+ """
131
+ batch, seqlen, dim = hidden_states.shape
132
+
133
+ conv_state, ssm_state = None, None
134
+ if inference_params is not None:
135
+ conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
136
+ if inference_params.seqlen_offset > 0:
137
+ # The states are updated inplace
138
+ out, _, _ = self.step(hidden_states, conv_state, ssm_state)
139
+ return out
140
+
141
+ # We do matmul and transpose BLH -> HBL at the same time
142
+ xz = rearrange(
143
+ self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
144
+ "d (b l) -> b d l",
145
+ l=seqlen,
146
+ )
147
+ if self.in_proj.bias is not None:
148
+ xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
149
+
150
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
151
+ # In the backward pass we write dx and dz next to each other to avoid torch.cat
152
+ if (
153
+ self.use_fast_path
154
+ and causal_conv1d_fn is not None
155
+ and inference_params is None
156
+ ): # Doesn't support outputting the states
157
+ out = mamba_inner_fn(
158
+ xz,
159
+ self.conv1d.weight,
160
+ self.conv1d.bias,
161
+ self.x_proj.weight,
162
+ self.dt_proj.weight,
163
+ self.out_proj.weight,
164
+ self.out_proj.bias,
165
+ A,
166
+ None, # input-dependent B
167
+ None, # input-dependent C
168
+ self.D.float(),
169
+ delta_bias=self.dt_proj.bias.float(),
170
+ delta_softplus=True,
171
+ )
172
+ else:
173
+ x, z = xz.chunk(2, dim=1)
174
+ # Compute short convolution
175
+ if conv_state is not None:
176
+ # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
177
+ # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
178
+ conv_state.copy_(
179
+ F.pad(x, (self.d_conv - x.shape[-1], 0))
180
+ ) # Update state (B D W)
181
+ if causal_conv1d_fn is None:
182
+ x = self.act(self.conv1d(x)[..., :seqlen])
183
+ else:
184
+ assert self.activation in ["silu", "swish"]
185
+ x = causal_conv1d_fn(
186
+ x=x,
187
+ weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
188
+ bias=self.conv1d.bias,
189
+ activation=self.activation,
190
+ )
191
+
192
+ # We're careful here about the layout, to avoid extra transposes.
193
+ # We want dt to have d as the slowest moving dimension
194
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
195
+ x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
196
+ dt, B, C = torch.split(
197
+ x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1
198
+ )
199
+ dt = self.dt_proj.weight @ dt.t()
200
+ dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
201
+ B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
202
+ C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
203
+ assert self.activation in ["silu", "swish"]
204
+ y = selective_scan_fn(
205
+ x,
206
+ dt,
207
+ A,
208
+ B,
209
+ C,
210
+ self.D.float(),
211
+ z=z,
212
+ delta_bias=self.dt_proj.bias.float(),
213
+ delta_softplus=True,
214
+ return_last_state=ssm_state is not None,
215
+ )
216
+ if ssm_state is not None:
217
+ y, last_state = y
218
+ ssm_state.copy_(last_state)
219
+ y = rearrange(y, "b d l -> b l d")
220
+ out = self.out_proj(y)
221
+ return out
222
+
223
+ def step(self, hidden_states, conv_state, ssm_state):
224
+ dtype = hidden_states.dtype
225
+ assert (
226
+ hidden_states.shape[1] == 1
227
+ ), "Only support decoding with 1 token at a time for now"
228
+ xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
229
+ x, z = xz.chunk(2, dim=-1) # (B D)
230
+
231
+ # Conv step
232
+ if causal_conv1d_update is None:
233
+ conv_state.copy_(
234
+ torch.roll(conv_state, shifts=-1, dims=-1)
235
+ ) # Update state (B D W)
236
+ conv_state[:, :, -1] = x
237
+ x = torch.sum(
238
+ conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
239
+ ) # (B D)
240
+ if self.conv1d.bias is not None:
241
+ x = x + self.conv1d.bias
242
+ x = self.act(x).to(dtype=dtype)
243
+ else:
244
+ x = causal_conv1d_update(
245
+ x,
246
+ conv_state,
247
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
248
+ self.conv1d.bias,
249
+ self.activation,
250
+ )
251
+
252
+ x_db = self.x_proj(x) # (B dt_rank+2*d_state)
253
+ dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
254
+ # Don't add dt_bias here
255
+ dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
256
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
257
+
258
+ # SSM step
259
+ if selective_state_update is None:
260
+ # Discretize A and B
261
+ dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
262
+ dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
263
+ dB = torch.einsum("bd,bn->bdn", dt, B)
264
+ ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
265
+ y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
266
+ y = y + self.D.to(dtype) * x
267
+ y = y * self.act(z) # (B D)
268
+ else:
269
+ y = selective_state_update(
270
+ ssm_state,
271
+ x,
272
+ dt,
273
+ A,
274
+ B,
275
+ C,
276
+ self.D,
277
+ z=z,
278
+ dt_bias=self.dt_proj.bias,
279
+ dt_softplus=True,
280
+ )
281
+
282
+ out = self.out_proj(y)
283
+ return out.unsqueeze(1), conv_state, ssm_state
284
+
285
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
286
+ device = self.out_proj.weight.device
287
+ conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
288
+ conv_state = torch.zeros(
289
+ batch_size,
290
+ self.d_model * self.expand,
291
+ self.d_conv,
292
+ device=device,
293
+ dtype=conv_dtype,
294
+ )
295
+ ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
296
+ # ssm_dtype = torch.float32
297
+ ssm_state = torch.zeros(
298
+ batch_size,
299
+ self.d_model * self.expand,
300
+ self.d_state,
301
+ device=device,
302
+ dtype=ssm_dtype,
303
+ )
304
+ return conv_state, ssm_state
305
+
306
+ def _get_states_from_cache(
307
+ self, inference_params, batch_size, initialize_states=False
308
+ ):
309
+ assert self.layer_idx is not None
310
+ if self.layer_idx not in inference_params.key_value_memory_dict:
311
+ batch_shape = (batch_size,)
312
+ conv_state = torch.zeros(
313
+ batch_size,
314
+ self.d_model * self.expand,
315
+ self.d_conv,
316
+ device=self.conv1d.weight.device,
317
+ dtype=self.conv1d.weight.dtype,
318
+ )
319
+ ssm_state = torch.zeros(
320
+ batch_size,
321
+ self.d_model * self.expand,
322
+ self.d_state,
323
+ device=self.dt_proj.weight.device,
324
+ dtype=self.dt_proj.weight.dtype,
325
+ # dtype=torch.float32,
326
+ )
327
+ inference_params.key_value_memory_dict[self.layer_idx] = (
328
+ conv_state,
329
+ ssm_state,
330
+ )
331
+ else:
332
+ conv_state, ssm_state = inference_params.key_value_memory_dict[
333
+ self.layer_idx
334
+ ]
335
+ # TODO: What if batch size changes between generation, and we reuse the same states?
336
+ if initialize_states:
337
+ conv_state.zero_()
338
+ ssm_state.zero_()
339
+ return conv_state, ssm_state
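Similarly, a hedged sketch of calling the Mamba block above on a prompt (sizes and device are assumptions): the fast path uses mamba_inner_fn when causal-conv1d is installed, otherwise it falls back to the unfused selective_scan_fn path.

# Hypothetical usage, not part of this diff:
# import torch
# block = Mamba(d_model=128, d_state=16, d_conv=4, expand=2, layer_idx=0, device="cuda")
# x = torch.randn(2, 64, 128, device="cuda")
# y = block(x)  # (2, 64, 128); pass inference_params to use the cached step() path during decoding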
build/torch210-cxx11-cu126-x86_64-linux/modules/mha.py ADDED
@@ -0,0 +1,294 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+
10
+ try:
11
+ from flash_attn import flash_attn_with_kvcache
12
+ except ImportError:
13
+ flash_attn_with_kvcache = None
14
+
15
+ try:
16
+ from flash_attn.layers.rotary import RotaryEmbedding
17
+ except ImportError:
18
+ RotaryEmbedding = None
19
+
20
+ try:
21
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
22
+ except ImportError:
23
+ causal_conv1d_fn, causal_conv1d_update = None, None
24
+
25
+
26
+ def _update_kv_cache(kv, inference_params, layer_idx):
27
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
28
+ # Pre-allocate memory for key-values for inference.
29
+ num_heads, head_dim = kv.shape[-2:]
30
+ assert layer_idx in inference_params.key_value_memory_dict
31
+ kv_cache, _ = inference_params.key_value_memory_dict[layer_idx]
32
+ # Adjust key and value for inference
33
+ batch_start = inference_params.batch_size_offset
34
+ batch_end = batch_start + kv.shape[0]
35
+ sequence_start = inference_params.seqlen_offset
36
+ sequence_end = sequence_start + kv.shape[1]
37
+ assert batch_end <= kv_cache.shape[0]
38
+ assert sequence_end <= kv_cache.shape[1]
39
+ assert kv_cache is not None
40
+ kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
41
+ return kv_cache[batch_start:batch_end, :sequence_end, ...]
42
+
43
+
44
+ class MHA(nn.Module):
45
+ """Multi-head self-attention and cross-attention"""
46
+
47
+ def __init__(
48
+ self,
49
+ embed_dim,
50
+ num_heads,
51
+ num_heads_kv=None,
52
+ head_dim=None, # If None, use embed_dim // num_heads
53
+ mlp_dim=0,
54
+ qkv_proj_bias=True,
55
+ out_proj_bias=True,
56
+ softmax_scale=None,
57
+ causal=False,
58
+ layer_idx=None,
59
+ d_conv=0,
60
+ rotary_emb_dim=0,
61
+ rotary_emb_base=10000.0,
62
+ rotary_emb_interleaved=False,
63
+ device=None,
64
+ dtype=None,
65
+ ) -> None:
66
+ """
67
+ num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
68
+ return_residual: whether to return the input x along with the output. This is for
69
+ performance reason: for post-norm architecture, returning the input allows us
70
+ to fuse the backward of nn.Linear with the residual connection.
71
+ """
72
+ factory_kwargs = {"device": device, "dtype": dtype}
73
+ super().__init__()
74
+ self.embed_dim = embed_dim
75
+ self.layer_idx = layer_idx
76
+ self.d_conv = d_conv
77
+ self.rotary_emb_dim = rotary_emb_dim
78
+ self.softmax_scale = softmax_scale
79
+ self.causal = causal
80
+
81
+ self.num_heads = num_heads
82
+ self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
83
+ assert (
84
+ self.num_heads % self.num_heads_kv == 0
85
+ ), "num_heads must be divisible by num_heads_kv"
86
+ if head_dim is None:
87
+ assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
88
+ self.head_dim = head_dim if head_dim is not None else self.embed_dim // num_heads
89
+ self.mlp_dim = math.ceil(mlp_dim / 256) * 256
90
+ qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
91
+ out_dim = self.head_dim * self.num_heads
92
+
93
+ if self.rotary_emb_dim > 0:
94
+ assert RotaryEmbedding is not None, "rotary requires flash_attn to be installed"
95
+ self.rotary_emb = RotaryEmbedding(
96
+ self.rotary_emb_dim,
97
+ base=rotary_emb_base,
98
+ interleaved=rotary_emb_interleaved,
99
+ device=device,
100
+ )
101
+
102
+ self.in_proj = nn.Linear(embed_dim, qkv_dim + self.mlp_dim, bias=qkv_proj_bias, **factory_kwargs)
103
+ if self.d_conv > 0:
104
+ self.conv1d = nn.Conv1d(
105
+ qkv_dim, qkv_dim, kernel_size=self.d_conv, padding=self.d_conv - 1, groups=qkv_dim,
106
+ **factory_kwargs
107
+ )
108
+ self.out_proj = nn.Linear(out_dim + self.mlp_dim // 2, embed_dim, bias=out_proj_bias, **factory_kwargs)
109
+
110
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
111
+ dtype = self.out_proj.weight.dtype if dtype is None else dtype
112
+ device = self.out_proj.weight.device
113
+ if self.d_conv > 0:
114
+ conv_state = torch.zeros(
115
+ batch_size, self.conv1d.weight.shape[0], self.d_conv, device=device, dtype=dtype
116
+ )
117
+ else:
118
+ conv_state = None
119
+ kv_cache = torch.empty(
120
+ batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim, dtype=dtype, device=device,
121
+ )
122
+ return kv_cache, conv_state
123
+
124
+ def _update_kv_cache(self, kv, inference_params):
125
+ """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
126
+ assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
127
+ return _update_kv_cache(kv, inference_params, self.layer_idx)
128
+
129
+ def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
130
+ """
131
+ Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
132
+ q: (batch_size, seqlen_q, nheads, head_dim)
133
+ kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
134
+ """
135
+ assert inference_params is not None and inference_params.seqlen_offset > 0
136
+ if self.rotary_emb_dim > 0:
137
+ self.rotary_emb._update_cos_sin_cache(
138
+ inference_params.max_seqlen, device=q.device, dtype=q.dtype
139
+ )
140
+ rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
141
+ else:
142
+ rotary_cos, rotary_sin = None, None
143
+ batch = q.shape[0]
144
+ kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
145
+ kv_cache = kv_cache[:batch]
146
+ cache_seqlens = (
147
+ inference_params.lengths_per_sample[:batch]
148
+ if inference_params.lengths_per_sample is not None
149
+ else inference_params.seqlen_offset
150
+ )
151
+ assert flash_attn_with_kvcache is not None, "flash_attn must be installed"
152
+ context = flash_attn_with_kvcache(
153
+ q,
154
+ kv_cache[:, :, 0],
155
+ kv_cache[:, :, 1],
156
+ kv[:, :, 0],
157
+ kv[:, :, 1],
158
+ rotary_cos=rotary_cos,
159
+ rotary_sin=rotary_sin,
160
+ cache_seqlens=cache_seqlens,
161
+ softmax_scale=self.softmax_scale,
162
+ causal=self.causal,
163
+ rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
164
+ )
165
+ return context
166
+
167
+ def _update_kvcache_attention(self, q, kv, inference_params):
168
+ """Write kv to inference_params, then do attention"""
169
+ if (
170
+ inference_params.seqlen_offset == 0
171
+ or flash_attn_with_kvcache is None
172
+ ):
173
+ # TODO: this only uses seqlen_offset and not lengths_per_sample.
174
+ kv = self._update_kv_cache(kv, inference_params)
175
+ k, v = kv.unbind(dim=-3)
176
+ k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
177
+ v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
178
+ return F.scaled_dot_product_attention(
179
+ q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale
180
+ ).transpose(1, 2)
181
+ else:
182
+ batch = q.shape[0]
183
+ kv_cache, _ = inference_params.key_value_memory_dict[self.layer_idx]
184
+ kv_cache = kv_cache[:batch]
185
+ cache_seqlens = (
186
+ inference_params.lengths_per_sample[:batch]
187
+ if inference_params.lengths_per_sample is not None
188
+ else inference_params.seqlen_offset
189
+ )
190
+ return flash_attn_with_kvcache(
191
+ q,
192
+ kv_cache[:, :, 0],
193
+ kv_cache[:, :, 1],
194
+ kv[:, :, 0],
195
+ kv[:, :, 1],
196
+ cache_seqlens=cache_seqlens,
197
+ softmax_scale=self.softmax_scale,
198
+ causal=self.causal,
199
+ )
200
+
201
+ def forward(self, x, inference_params=None):
202
+ """
203
+ Arguments:
204
+ x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
205
+ cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
206
+ is the is the sum of the sequence lengths in the batch.
207
+ inference_params: for generation. Adapted from Megatron-LM (and Apex)
208
+ https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
209
+ """
210
+ if inference_params is not None and self.layer_idx not in inference_params.key_value_memory_dict:
211
+ inference_params.key_value_memory_dict[self.layer_idx] = self.allocate_inference_cache(
212
+ x.shape[0], inference_params.max_seqlen, dtype=x.dtype
213
+ )
214
+ seqlen_offset = (
215
+ 0
216
+ if inference_params is None
217
+ else (
218
+ inference_params.lengths_per_sample
219
+ if inference_params.lengths_per_sample is not None
220
+ else inference_params.seqlen_offset
221
+ )
222
+ )
223
+ rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None
224
+ qkv = self.in_proj(x)
225
+ if self.mlp_dim > 0:
226
+ qkv, x_mlp = qkv.split([qkv.shape[-1] - self.mlp_dim, self.mlp_dim], dim=-1)
227
+ x_mlp_up, x_mlp_gate = x_mlp.chunk(2, dim=-1)
228
+ x_mlp = x_mlp_up * F.silu(x_mlp_gate)
229
+ if self.d_conv > 0:
230
+ # The inference code for conv1d is pretty messy, should clean it up
231
+ if (inference_params is None or inference_params.seqlen_offset == 0):
232
+ if causal_conv1d_fn is None:
233
+ qkv = rearrange(
234
+ self.conv1d(rearrange(qkv, "b s d -> b d s"))[..., :-(self.d_conv - 1)], "b d s -> b s d"
235
+ ).contiguous()
236
+ else:
237
+ qkv = causal_conv1d_fn(
238
+ qkv.transpose(1, 2),
239
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
240
+ self.conv1d.bias
241
+ ).transpose(1, 2)
242
+ if inference_params is not None:
243
+ _, conv_state = inference_params.key_value_memory_dict[self.layer_idx]
244
+ # If we just take qkv[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
245
+ # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
246
+ qkv_t = rearrange(qkv, "b l d -> b d l")
247
+ conv_state.copy_(F.pad(qkv_t, (self.d_conv - qkv_t.shape[-1], 0))) # Update state (B D W)
248
+ else:
249
+ _, conv_state = inference_params.key_value_memory_dict[self.layer_idx]
250
+ assert qkv.shape[1] == 1, "Only support decoding with 1 token at a time for now"
251
+ qkv = qkv.squeeze(1)
252
+ # Conv step
253
+ if causal_conv1d_update is None:
254
+ conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
255
+ conv_state[:, :, -1] = qkv
256
+ qkv = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
257
+ if self.conv1d.bias is not None:
258
+ qkv = qkv + self.conv1d.bias
259
+ else:
260
+ qkv = causal_conv1d_update(
261
+ qkv,
262
+ conv_state,
263
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
264
+ self.conv1d.bias
265
+ )
266
+ qkv = qkv.unsqueeze(1)
267
+ q, kv = qkv.split([self.num_heads * self.head_dim, self.num_heads_kv * 2 * self.head_dim], dim=-1)
268
+ q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
269
+ kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
270
+ if (
271
+ inference_params is None
272
+ or inference_params.seqlen_offset == 0
273
+ or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
274
+ ):
275
+ if self.rotary_emb_dim > 0:
276
+ q, kv = self.rotary_emb(
277
+ q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
278
+ )
279
+ if inference_params is None:
280
+ k, v = kv.unbind(dim=-3)
281
+ k = torch.repeat_interleave(k, dim=2, repeats=self.num_heads // self.num_heads_kv)
282
+ v = torch.repeat_interleave(v, dim=2, repeats=self.num_heads // self.num_heads_kv)
283
+ context = F.scaled_dot_product_attention(
284
+ q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=self.causal, scale=self.softmax_scale
285
+ ).transpose(1, 2)
286
+ else:
287
+ context = self._update_kvcache_attention(q, kv, inference_params)
288
+ else:
289
+ context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
290
+ context = rearrange(context, "... h d -> ... (h d)")
291
+ if self.mlp_dim > 0:
292
+ context = torch.cat([context, x_mlp], dim=-1)
293
+ out = self.out_proj(context)
294
+ return out
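A hedged sketch of the MHA module above in its simplest configuration (no rotary embedding, no convolution, no inference cache), which only relies on PyTorch's scaled_dot_product_attention; the sizes are illustrative assumptions.

# Hypothetical usage, not part of this diff:
# import torch
# attn = MHA(embed_dim=256, num_heads=8, num_heads_kv=2, causal=True)  # GQA: 8 query heads share 2 KV heads
# x = torch.randn(2, 32, 256)
# y = attn(x)  # (2, 32, 256); flash-attn is only needed for the kv-cache decode paths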
build/torch210-cxx11-cu126-x86_64-linux/modules/mlp.py ADDED
@@ -0,0 +1,34 @@
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
+ from torch import nn
+ from torch.nn import functional as F
+
+
+ class GatedMLP(nn.Module):
+     def __init__(
+         self,
+         in_features,
+         hidden_features=None,
+         out_features=None,
+         activation=F.silu,
+         bias=False,
+         multiple_of=128,
+         device=None,
+         dtype=None,
+     ):
+         factory_kwargs = {"device": device, "dtype": dtype}
+         super().__init__()
+         out_features = out_features if out_features is not None else in_features
+         hidden_features = (
+             hidden_features if hidden_features is not None else int(8 * in_features / 3)
+         )
+         hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
+         self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias, **factory_kwargs)
+         self.activation = activation
+         self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)
+
+     def forward(self, x):
+         y = self.fc1(x)
+         y, gate = y.chunk(2, dim=-1)
+         y = y * self.activation(gate)
+         y = self.fc2(y)
+         return y
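A brief usage note: the hidden width defaults to roughly 8/3 of in_features and is then rounded up to a multiple of multiple_of. A hedged example (the numbers below are just the arithmetic implied by the defaults, not part of this diff):

# Hypothetical usage:
# import torch
# mlp = GatedMLP(in_features=512)      # int(8 * 512 / 3) = 1365, rounded up to 1408 (multiple_of=128)
# y = mlp(torch.randn(2, 16, 512))     # output keeps in_features: (2, 16, 512)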
build/torch210-cxx11-cu126-x86_64-linux/modules/ssd_minimal.py ADDED
@@ -0,0 +1,111 @@
+ # Copyright (c) 2024, Albert Gu and Tri Dao.
+ """Minimal implementation of SSD.
+
+ This is the same as Listing 1 from the paper.
+ """
+
+ import torch
+ import torch.nn.functional as F
+ from einops import rearrange, repeat
+
+ from ..ops.triton.ssd_combined import mamba_chunk_scan_combined
+
+
+ def segsum_unstable(x):
+     """Naive segment sum calculation."""
+     T = x.size(-1)
+     x_cumsum = torch.cumsum(x, dim=-1)
+     x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
+     mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
+     x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
+     return x_segsum
+
+
+ def segsum(x):
+     """More stable segment sum calculation."""
+     T = x.size(-1)
+     x = repeat(x, "... d -> ... d e", e=T)
+     mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1)
+     x = x.masked_fill(~mask, 0)
+     x_segsum = torch.cumsum(x, dim=-2)
+     mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
+     x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
+     return x_segsum
+
+
+ def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
+     """
+     Arguments:
+         X: (batch, length, n_heads, d_head)
+         A: (batch, length, n_heads)
+         B: (batch, length, n_heads, d_state)
+         C: (batch, length, n_heads, d_state)
+     Return:
+         Y: (batch, length, n_heads, d_head)
+     """
+     assert X.dtype == A.dtype == B.dtype == C.dtype
+     assert X.shape[1] % block_len == 0
+
+     # Rearrange into blocks/chunks
+     X, A, B, C = [
+         rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)
+     ]
+
+     A = rearrange(A, "b c l h -> b h c l")
+     A_cumsum = torch.cumsum(A, dim=-1)
+
+     # 1. Compute the output for each intra-chunk (diagonal blocks)
+     L = torch.exp(segsum(A))
+     Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)
+
+     # 2. Compute the state for each intra-chunk
+     # (right term of low-rank factorization of off-diagonal blocks; B terms)
+     decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
+     states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)
+
+     # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
+     # (middle term of factorization of off-diag blocks; A terms)
+     if initial_states is None:
+         initial_states = torch.zeros_like(states[:, :1])
+     states = torch.cat([initial_states, states], dim=1)
+     decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
+     new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
+     states, final_state = new_states[:, :-1], new_states[:, -1]
+
+     # 4. Compute state -> output conversion per chunk
+     # (left term of low-rank factorization of off-diagonal blocks; C terms)
+     state_decay_out = torch.exp(A_cumsum)
+     Y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", C, states, state_decay_out)
+
+     # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
+     Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
+     return Y, final_state
+
+
+ # Simple test
+ def test_correctness():
+     torch.manual_seed(42)
+
+     ## Dimensions
+     # Denoted (B, T, Q, D, P) in the paper
+     batch, seqlen, chunk_size, dim, headdim = 1, 2048, 64, 2048, 64
+     nheads = dim // headdim  # (H) in the paper
+     ngroups = 1  # (G) in the paper
+     dstate = 64  # (N) in the paper
+     dtype = torch.float32
+     device = "cuda"
+
+     x = torch.randn(batch, seqlen, nheads, headdim, dtype=dtype, device=device)
+     dt = F.softplus(
+         torch.randn(batch, seqlen, nheads, dtype=torch.float32, device=device) - 4
+     ).requires_grad_()
+     A = (
+         -torch.exp(torch.rand(nheads, dtype=torch.float32, device=device))
+     ).requires_grad_()
+     B = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)
+     C = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device)
+     D = torch.randn(nheads, dtype=dtype, device=device)
+
+     # Comparing fused version and minimal version
+     y = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None)
+     y_min, _ = ssd_minimal_discrete(x * dt.unsqueeze(-1), A * dt, B, C, chunk_size)
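Note that test_correctness computes both the fused and minimal outputs but does not compare them; a hedged follow-up check one might append inside the test (the tolerances are illustrative, not taken from this diff) is:

#     torch.testing.assert_close(y, y_min, atol=1e-2, rtol=1e-2)  # illustrative tolerance, not part of the original test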
build/torch210-cxx11-cu126-x86_64-linux/ops/__init__.py ADDED
File without changes
build/torch210-cxx11-cu126-x86_64-linux/ops/selective_scan_interface.py ADDED
@@ -0,0 +1,446 @@
1
+ # Copyright (c) 2023, Tri Dao, Albert Gu.
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from ..utils.torch import custom_fwd, custom_bwd
6
+
7
+ from einops import rearrange, repeat
8
+
9
+ try:
10
+ from causal_conv1d import causal_conv1d_fn
11
+ from causal_conv1d.causal_conv1d_interface import causal_conv1d_cuda
12
+ except ImportError:
13
+ causal_conv1d_fn = None
14
+ causal_conv1d_cuda = None
15
+
16
+ from .triton.layer_norm import _layer_norm_fwd
17
+
18
+ from .._ops import ops
19
+
20
+
21
+ class SelectiveScanFn(torch.autograd.Function):
22
+
23
+ @staticmethod
24
+ def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
25
+ return_last_state=False):
26
+ if u.stride(-1) != 1:
27
+ u = u.contiguous()
28
+ if delta.stride(-1) != 1:
29
+ delta = delta.contiguous()
30
+ if D is not None:
31
+ D = D.contiguous()
32
+ if B.stride(-1) != 1:
33
+ B = B.contiguous()
34
+ if C.stride(-1) != 1:
35
+ C = C.contiguous()
36
+ if z is not None and z.stride(-1) != 1:
37
+ z = z.contiguous()
38
+ if B.dim() == 3:
39
+ B = rearrange(B, "b dstate l -> b 1 dstate l")
40
+ ctx.squeeze_B = True
41
+ if C.dim() == 3:
42
+ C = rearrange(C, "b dstate l -> b 1 dstate l")
43
+ ctx.squeeze_C = True
44
+ out, x, *rest = ops.selective_scan_fwd(
45
+ u, delta, A, B, C, D, z, delta_bias, delta_softplus
46
+ )
47
+ ctx.delta_softplus = delta_softplus
48
+ ctx.has_z = z is not None
49
+ last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
50
+ if not ctx.has_z:
51
+ ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
52
+ return out if not return_last_state else (out, last_state)
53
+ else:
54
+ ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
55
+ out_z = rest[0]
56
+ return out_z if not return_last_state else (out_z, last_state)
57
+
58
+ @staticmethod
59
+ def backward(ctx, dout, *args):
60
+ if not ctx.has_z:
61
+ u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
62
+ z = None
63
+ out = None
64
+ else:
65
+ u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
66
+ if dout.stride(-1) != 1:
67
+ dout = dout.contiguous()
68
+ # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
69
+ # backward of selective_scan_cuda with the backward of chunk).
70
+ # Here we just pass in None and dz will be allocated in the C++ code.
71
+ du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = ops.selective_scan_bwd(
72
+ u,
73
+ delta,
74
+ A,
75
+ B,
76
+ C,
77
+ D,
78
+ z,
79
+ delta_bias,
80
+ dout,
81
+ x,
82
+ out,
83
+ None,
84
+ ctx.delta_softplus,
85
+ False, # option to recompute out_z, not used here
86
+ )
87
+ dz = rest[0] if ctx.has_z else None
88
+ dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
89
+ dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
90
+ return (du, ddelta, dA, dB, dC,
91
+ dD if D is not None else None,
92
+ dz,
93
+ ddelta_bias if delta_bias is not None else None,
94
+ None,
95
+ None)
96
+
97
+
98
+ def rms_norm_forward(
99
+ x,
100
+ weight,
101
+ bias,
102
+ eps=1e-6,
103
+ is_rms_norm=True,
104
+ ):
105
+ # x (b l) d
106
+ if x.stride(-1) != 1:
107
+ x = x.contiguous()
108
+ weight = weight.contiguous()
109
+ if bias is not None:
110
+ bias = bias.contiguous()
111
+ y = _layer_norm_fwd(
112
+ x, weight, bias, eps, None, residual_dtype=None, is_rms_norm=is_rms_norm
113
+ )[0]
114
+ # y (b l) d
115
+ return y
116
+
117
+
118
+ def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
119
+ return_last_state=False):
120
+ """if return_last_state is True, returns (out, last_state)
121
+ last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
122
+ not considered in the backward pass.
123
+ """
124
+ return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
125
+
126
+
127
+ def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
128
+ return_last_state=False):
129
+ """
130
+ u: r(B D L)
131
+ delta: r(B D L)
132
+ A: c(D N) or r(D N)
133
+ B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
134
+ C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
135
+ D: r(D)
136
+ z: r(B D L)
137
+ delta_bias: r(D), fp32
138
+
139
+ out: r(B D L)
140
+ last_state (optional): r(B D dstate) or c(B D dstate)
141
+ """
142
+ dtype_in = u.dtype
143
+ u = u.float()
144
+ delta = delta.float()
145
+ if delta_bias is not None:
146
+ delta = delta + delta_bias[..., None].float()
147
+ if delta_softplus:
148
+ delta = F.softplus(delta)
149
+ batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
150
+ is_variable_B = B.dim() >= 3
151
+ is_variable_C = C.dim() >= 3
152
+ if A.is_complex():
153
+ if is_variable_B:
154
+ B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
155
+ if is_variable_C:
156
+ C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
157
+ else:
158
+ B = B.float()
159
+ C = C.float()
160
+ x = A.new_zeros((batch, dim, dstate))
161
+ ys = []
162
+ deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
163
+ if not is_variable_B:
164
+ deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
165
+ else:
166
+ if B.dim() == 3:
167
+ deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
168
+ else:
169
+ B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
170
+ deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
171
+ if is_variable_C and C.dim() == 4:
172
+ C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
173
+ last_state = None
174
+ for i in range(u.shape[2]):
175
+ x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
176
+ if not is_variable_C:
177
+ y = torch.einsum('bdn,dn->bd', x, C)
178
+ else:
179
+ if C.dim() == 3:
180
+ y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
181
+ else:
182
+ y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
183
+ if i == u.shape[2] - 1:
184
+ last_state = x
185
+ if y.is_complex():
186
+ y = y.real * 2
187
+ ys.append(y)
188
+ y = torch.stack(ys, dim=2) # (batch dim L)
189
+ out = y if D is None else y + u * rearrange(D, "d -> d 1")
190
+ if z is not None:
191
+ out = out * F.silu(z)
192
+ out = out.to(dtype=dtype_in)
193
+ return out if not return_last_state else (out, last_state)
194
+
195
+
196
+ class MambaInnerFn(torch.autograd.Function):
197
+
198
+ @staticmethod
199
+ @custom_fwd
200
+ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
201
+ out_proj_weight, out_proj_bias,
202
+ A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
203
+ C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1, b_rms_weight=None, c_rms_weight= None, dt_rms_weight= None, b_c_dt_rms_eps=1e-6):
204
+ """
205
+ xz: (batch, dim, seqlen)
206
+ """
207
+ assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
208
+ assert checkpoint_lvl in [0, 1]
209
+ L = xz.shape[-1]
210
+ delta_rank = delta_proj_weight.shape[1]
211
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
212
+ if torch.is_autocast_enabled():
213
+ x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
214
+ delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
215
+ out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
216
+ out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
217
+ if out_proj_bias is not None else None)
218
+ if xz.stride(-1) != 1:
219
+ xz = xz.contiguous()
220
+ conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
221
+ x, z = xz.chunk(2, dim=1)
222
+ conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
223
+ conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
224
+ x, conv1d_weight, conv1d_bias, None, None, None, True
225
+ )
226
+ # We're being very careful here about the layout, to avoid extra transposes.
227
+ # We want delta to have d as the slowest moving dimension
228
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
229
+ x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
230
+ delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L)
231
+ ctx.is_variable_B = B is None
232
+ ctx.is_variable_C = C is None
233
+ ctx.B_proj_bias_is_None = B_proj_bias is None
234
+ ctx.C_proj_bias_is_None = C_proj_bias is None
235
+ if B is None: # variable B
236
+ B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate)
237
+ if B_proj_bias is not None:
238
+ B = B + B_proj_bias.to(dtype=B.dtype)
239
+ if not A.is_complex():
240
+ # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
241
+ B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
242
+ else:
243
+ B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
244
+ else:
245
+ if B.stride(-1) != 1:
246
+ B = B.contiguous()
247
+ if C is None: # variable C
248
+ C = x_dbl[:, -d_state:] # (bl dstate)
249
+ if C_proj_bias is not None:
250
+ C = C + C_proj_bias.to(dtype=C.dtype)
251
+ if not A.is_complex():
252
+ # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
253
+ C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
254
+ else:
255
+ C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
256
+ else:
257
+ if C.stride(-1) != 1:
258
+ C = C.contiguous()
259
+ if D is not None:
260
+ D = D.contiguous()
261
+
262
+ if b_rms_weight is not None:
263
+ B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
264
+ B = rms_norm_forward(B, b_rms_weight, bias=None, eps=b_c_dt_rms_eps)
265
+ B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
266
+ if c_rms_weight is not None:
267
+ C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
268
+ C = rms_norm_forward(C, c_rms_weight, bias=None, eps=b_c_dt_rms_eps)
269
+ C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
270
+ if dt_rms_weight is not None:
271
+ delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
272
+ delta = rms_norm_forward(delta, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps)
273
+ delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
274
+
275
+ out, scan_intermediates, out_z = ops.selective_scan_fwd(
276
+ conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
277
+ )
278
+ ctx.delta_softplus = delta_softplus
279
+ ctx.out_proj_bias_is_None = out_proj_bias is None
280
+ ctx.checkpoint_lvl = checkpoint_lvl
281
+ ctx.b_rms_weight = b_rms_weight
282
+ ctx.c_rms_weight = c_rms_weight
283
+ ctx.dt_rms_weight = dt_rms_weight
284
+ ctx.b_c_dt_rms_eps = b_c_dt_rms_eps
285
+ if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass
286
+ conv1d_out, delta = None, None
287
+ ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
288
+ delta_proj_weight, out_proj_weight, conv1d_out, delta,
289
+ A, B, C, D, delta_bias, scan_intermediates, b_rms_weight, c_rms_weight, dt_rms_weight, out)
290
+ return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
291
+
292
+ @staticmethod
293
+ @custom_bwd
294
+ def backward(ctx, dout):
295
+ # dout: (batch, seqlen, dim)
296
+ assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
297
+ (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
298
+ conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, b_rms_weight, c_rms_weight, dt_rms_weight, out) = ctx.saved_tensors
299
+ L = xz.shape[-1]
300
+ delta_rank = delta_proj_weight.shape[1]
301
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
302
+ x, z = xz.chunk(2, dim=1)
303
+ if dout.stride(-1) != 1:
304
+ dout = dout.contiguous()
305
+ if ctx.checkpoint_lvl == 1:
306
+ conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
307
+ x, conv1d_weight, conv1d_bias, None, None, None, True
308
+ )
309
+ delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
310
+ "d (b l) -> b d l", l = L)
311
+ if dt_rms_weight is not None:
312
+ delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
313
+ delta = rms_norm_forward(delta, ctx.dt_rms_weight, None, ctx.b_c_dt_rms_eps)
314
+ delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
315
+ if b_rms_weight is not None:
316
+ # Recompute & RMSNorm B
317
+ B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
318
+ B = rms_norm_forward(
319
+ B, ctx.b_rms_weight, None, ctx.b_c_dt_rms_eps
320
+ )
321
+ B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
322
+ if c_rms_weight is not None:
323
+ # Recompute & RMSNorm C
324
+ C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
325
+ C = rms_norm_forward(
326
+ C, ctx.c_rms_weight, None, ctx.b_c_dt_rms_eps
327
+ )
328
+ C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
329
+
330
+ # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
331
+ # backward of selective_scan_cuda with the backward of chunk).
332
+ dxz = torch.empty_like(xz) # (batch, dim, seqlen)
333
+ dx, dz = dxz.chunk(2, dim=1)
334
+ dout = rearrange(dout, "b l e -> e (b l)")
335
+ dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
336
+ dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = (
337
+ ops.selective_scan_bwd(
338
+ conv1d_out,
339
+ delta,
340
+ A,
341
+ B,
342
+ C,
343
+ D,
344
+ z,
345
+ delta_bias,
346
+ dout_y,
347
+ scan_intermediates,
348
+ out,
349
+ dz,
350
+ ctx.delta_softplus,
351
+ True, # option to recompute out_z
352
+ )
353
+ )
354
+ dout_proj_weight = torch.einsum(
355
+ "eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")
356
+ )
357
+ dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
358
+ dD = dD if D is not None else None
359
+ dx_dbl = torch.empty_like(x_dbl)
360
+ dB_proj_bias = None
361
+ if ctx.is_variable_B:
362
+ if not A.is_complex():
363
+ dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
364
+ else:
365
+ dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
366
+ dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
367
+ dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d)
368
+ dB = None
369
+ dC_proj_bias = None
370
+ if ctx.is_variable_C:
371
+ if not A.is_complex():
372
+ dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
373
+ else:
374
+ dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
375
+ dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
376
+ dx_dbl[:, -d_state:] = dC # (bl d)
377
+ dC = None
378
+ ddelta = rearrange(ddelta, "b d l -> d (b l)")
379
+ ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
380
+ dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
381
+ dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
382
+ dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
383
+ dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
384
+ dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
385
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
386
+ # backward of conv1d with the backward of chunk).
387
+ dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
388
+ x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
389
+ )
390
+ dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
391
+ dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
392
+ return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
393
+ dout_proj_weight, dout_proj_bias,
394
+ dA, dB, dC, dD,
395
+ ddelta_bias if delta_bias is not None else None,
396
+ # 6-None are delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps
397
+ dB_proj_bias, dC_proj_bias, None, None, None, None, None, None)
398
+
399
+
400
+ def mamba_inner_fn(
401
+ xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
402
+ out_proj_weight, out_proj_bias,
403
+ A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
404
+ C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1, b_rms_weight=None, c_rms_weight=None, dt_rms_weight=None, b_c_dt_rms_eps=1e-6
405
+ ):
406
+ return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
407
+ out_proj_weight, out_proj_bias,
408
+ A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps)
409
+
410
+
411
+ def mamba_inner_ref(
412
+ xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
413
+ out_proj_weight, out_proj_bias,
414
+ A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
415
+ C_proj_bias=None, delta_softplus=True
416
+ ):
417
+ assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d."
418
+ L = xz.shape[-1]
419
+ delta_rank = delta_proj_weight.shape[1]
420
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
421
+ x, z = xz.chunk(2, dim=1)
422
+ x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu")
423
+ # We're being very careful here about the layout, to avoid extra transposes.
424
+ # We want delta to have d as the slowest moving dimension
425
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
426
+ x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
427
+ delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
428
+ delta = rearrange(delta, "d (b l) -> b d l", l=L)
429
+ if B is None: # variable B
430
+ B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl d)
431
+ if B_proj_bias is not None:
432
+ B = B + B_proj_bias.to(dtype=B.dtype)
433
+ if not A.is_complex():
434
+ B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
435
+ else:
436
+ B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
437
+ if C is None: # variable C
438
+ C = x_dbl[:, -d_state:] # (bl d)
439
+ if C_proj_bias is not None:
440
+ C = C + C_proj_bias.to(dtype=C.dtype)
441
+ if not A.is_complex():
442
+ C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
443
+ else:
444
+ C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
445
+ y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
446
+ return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)
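For orientation, here is a minimal shape sketch of how `mamba_inner_fn` above is typically driven. It is illustrative only: it assumes a CUDA device with the causal-conv1d and selective-scan extensions available, that `mamba_inner_fn` from this module is in scope, and the dimension names (`d_model`, `d_inner`, `dt_rank`, `d_state`, `d_conv`) follow the usual Mamba block convention rather than anything defined in this file.

```python
# Illustrative shape sketch only -- assumes a CUDA device, the causal-conv1d /
# selective-scan extensions, and that mamba_inner_fn from this module is in scope.
import torch

batch, seqlen = 2, 64
d_model, d_inner, d_state, dt_rank, d_conv = 256, 512, 16, 16, 4
dev = "cuda"

xz = torch.randn(batch, 2 * d_inner, seqlen, device=dev, requires_grad=True)
conv1d_weight = torch.randn(d_inner, 1, d_conv, device=dev, requires_grad=True)
conv1d_bias = torch.randn(d_inner, device=dev, requires_grad=True)
x_proj_weight = torch.randn(dt_rank + 2 * d_state, d_inner, device=dev, requires_grad=True)
delta_proj_weight = torch.randn(d_inner, dt_rank, device=dev, requires_grad=True)
out_proj_weight = torch.randn(d_model, d_inner, device=dev, requires_grad=True)
A = (-torch.rand(d_inner, d_state, device=dev)).requires_grad_()
D = torch.ones(d_inner, device=dev, requires_grad=True)

# B and C are left as None so they are generated per timestep from x_proj
# (the "variable B/C" path handled above).
out = mamba_inner_fn(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, None, A, None, None, D,
    delta_bias=None, delta_softplus=True,
)
print(out.shape)      # expected: (batch, seqlen, d_model)
out.sum().backward()  # gradients flow back to all leaf tensors above
```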
build/torch210-cxx11-cu126-x86_64-linux/ops/triton/__init__.py ADDED
File without changes
build/torch210-cxx11-cu126-x86_64-linux/ops/triton/k_activations.py ADDED
@@ -0,0 +1,169 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ import torch
4
+
5
+ import triton
6
+ import triton.language as tl
7
+
8
+
9
+ @triton.autotune(
10
+ configs=[
11
+ triton.Config({'BLOCK_N': 32}),
12
+ triton.Config({'BLOCK_N': 64}),
13
+ triton.Config({'BLOCK_N': 128}),
14
+ triton.Config({'BLOCK_N': 256}),
15
+ triton.Config({'BLOCK_N': 512}),
16
+ triton.Config({'BLOCK_N': 1024}),
17
+ ],
18
+ key=['ncols'],
19
+ )
20
+ @triton.jit
21
+ def _swiglu_fwd_kernel(
22
+ X,
23
+ Y,
24
+ OUT,
25
+ stride_x_row, # how much to increase the pointer when moving by 1 row
26
+ stride_y_row,
27
+ stride_out_row,
28
+ ncols,
29
+ BLOCK_N: tl.constexpr,
30
+ ):
31
+ # Map the program id to the row of X and Y it should compute.
32
+ row = tl.program_id(0)
33
+ start_col = tl.program_id(1) * BLOCK_N
34
+ X += row * stride_x_row
35
+ Y += row * stride_y_row
36
+ OUT += row * stride_out_row
37
+ cols = start_col + tl.arange(0, BLOCK_N)
38
+ x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)
39
+ y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)
40
+ out = x * tl.sigmoid(x) * y
41
+ tl.store(OUT + cols, out, mask=cols < ncols)
42
+
43
+
44
+ def _swiglu_fwd(xy, out=None):
45
+ if xy.stride(-1) != 1:
46
+ xy = xy.contiguous()
47
+ batch_shape = xy.shape[:-1]
48
+ xy = xy.reshape(-1, xy.shape[-1])
49
+ x, y = xy.chunk(2, dim=-1)
50
+ if out is None:
51
+ out = torch.empty_like(x)
52
+ else:
53
+ out = out.reshape(-1, out.shape[-1])
54
+ assert out.shape == x.shape
55
+ assert out.stride(-1) == 1
56
+ M, N = x.shape
57
+ grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))
58
+ with torch.cuda.device(x.device.index):
59
+ _swiglu_fwd_kernel[grid](x, y, out, x.stride(0), y.stride(0), out.stride(0), N)
60
+ return out.reshape(*batch_shape, out.shape[-1])
61
+
62
+
63
+ @triton.autotune(
64
+ configs=[
65
+ triton.Config({'BLOCK_N': 32}),
66
+ triton.Config({'BLOCK_N': 64}),
67
+ triton.Config({'BLOCK_N': 128}),
68
+ triton.Config({'BLOCK_N': 256}),
69
+ triton.Config({'BLOCK_N': 512}),
70
+ triton.Config({'BLOCK_N': 1024}),
71
+ ],
72
+ key=['ncols'],
73
+ )
74
+ @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["OUT"] is not None})
75
+ @triton.jit
76
+ def _swiglu_bwd_kernel(
77
+ X,
78
+ Y,
79
+ DOUT,
80
+ OUT,
81
+ DX,
82
+ DY,
83
+ stride_x_row, # how much to increase the pointer when moving by 1 row
84
+ stride_y_row,
85
+ stride_dout_row,
86
+ stride_out_row,
87
+ stride_dx_row,
88
+ stride_dy_row,
89
+ ncols,
90
+ BLOCK_N: tl.constexpr,
91
+ RECOMPUTE_OUTPUT: tl.constexpr,
92
+ ):
93
+ # Map the program id to the row of X and Y it should compute.
94
+ row = tl.program_id(0)
95
+ start_col = tl.program_id(1) * BLOCK_N
96
+ X += row * stride_x_row
97
+ Y += row * stride_y_row
98
+ DOUT += row * stride_dout_row
99
+ if RECOMPUTE_OUTPUT:
100
+ OUT += row * stride_out_row
101
+ DX += row * stride_dx_row
102
+ DY += row * stride_dy_row
103
+ cols = start_col + tl.arange(0, BLOCK_N)
104
+ x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)
105
+ y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)
106
+ dout = tl.load(DOUT + cols, mask=cols < ncols, other=0.).to(tl.float32)
107
+ x_sigmoid = tl.sigmoid(x)
108
+ dx = x_sigmoid * (1 + x * (1 - x_sigmoid)) * y * dout
109
+ dy = x * x_sigmoid * dout
110
+ tl.store(DX + cols, dx, mask=cols < ncols)
111
+ tl.store(DY + cols, dy, mask=cols < ncols)
112
+ if RECOMPUTE_OUTPUT:
113
+ out = x * x_sigmoid * y
114
+ tl.store(OUT + cols, out, mask=cols < ncols)
115
+
116
+
117
+ def _swiglu_bwd(xy, dout, dxy=None, recompute_output=False, out=None):
118
+ if xy.stride(-1) != 1:
119
+ xy = xy.contiguous()
120
+ if dout.stride(-1) != 1:
121
+ dout = dout.contiguous()
122
+ batch_shape = xy.shape[:-1]
123
+ xy = xy.reshape(-1, xy.shape[-1])
124
+ x, y = xy.chunk(2, dim=-1)
125
+ dout = dout.reshape(-1, dout.shape[-1])
126
+ assert dout.shape == x.shape
127
+ if dxy is None:
128
+ dxy = torch.empty_like(xy)
129
+ else:
130
+ dxy = dxy.reshape(-1, dxy.shape[-1])
131
+ assert dxy.shape == xy.shape
132
+ dx, dy = dxy.chunk(2, dim=-1)
133
+ assert dx.stride(-1) == 1
134
+ assert dy.stride(-1) == 1
135
+ if recompute_output:
136
+ if out is None:
137
+ out = torch.empty_like(x)
138
+ else:
139
+ out = out.reshape(-1, out.shape[-1])
140
+ assert out.shape == x.shape
141
+ assert out.stride(-1) == 1
142
+ M, N = x.shape
143
+ grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))
144
+ with torch.cuda.device(x.device.index):
145
+ _swiglu_bwd_kernel[grid](x, y, dout, out if recompute_output else None, dx, dy,
146
+ x.stride(0), y.stride(0), dout.stride(0),
147
+ out.stride(0) if recompute_output else 0,
148
+ dx.stride(0), dy.stride(0),
149
+ N)
150
+ if not recompute_output:
151
+ return dxy.reshape(*batch_shape, dxy.shape[-1])
152
+ else:
153
+ return dxy.reshape(*batch_shape, dxy.shape[-1]), out.reshape(*batch_shape, out.shape[-1])
154
+
155
+
156
+ class SwiGLU(torch.autograd.Function):
157
+
158
+ @staticmethod
159
+ def forward(ctx, xy):
160
+ ctx.save_for_backward(xy)
161
+ return _swiglu_fwd(xy)
162
+
163
+ @staticmethod
164
+ def backward(ctx, dout):
165
+ xy, = ctx.saved_tensors
166
+ return _swiglu_bwd(xy, dout)
167
+
168
+
169
+ swiglu = SwiGLU.apply
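As a quick sanity check, the fused `swiglu` above can be compared against a pure-PyTorch reference, `F.silu(x) * y`, applied to the two halves of the packed input. A minimal sketch, assuming a CUDA device with Triton available and `swiglu` from this module in scope:

```python
# A sketch assuming a CUDA device with Triton and `swiglu` from this module in scope.
import torch
import torch.nn.functional as F

xy = torch.randn(4, 2 * 1024, device="cuda", requires_grad=True)

out = swiglu(xy)                 # fused Triton forward: silu(x) * y on the two halves
x, y = xy.chunk(2, dim=-1)
ref = F.silu(x) * y              # pure-PyTorch reference
print(torch.allclose(out, ref, atol=1e-5))

out.sum().backward()             # backward goes through _swiglu_bwd
grad_fused = xy.grad.clone()
xy.grad = None
ref.sum().backward()
print(torch.allclose(grad_fused, xy.grad, atol=1e-5))
```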
build/torch210-cxx11-cu126-x86_64-linux/ops/triton/layer_norm.py ADDED
@@ -0,0 +1,1113 @@
1
+ # Copyright (c) 2024, Tri Dao.
2
+ # Implement dropout + residual + layer_norm / rms_norm.
3
+
4
+ # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
5
+ # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
6
+ # This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
7
+ # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
8
+
9
+ import math
10
+ import warnings
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from ...utils.torch import custom_bwd, custom_fwd
15
+
16
+ import triton
17
+ import triton.language as tl
18
+
19
+
20
+ def layer_norm_ref(
21
+ x,
22
+ weight,
23
+ bias,
24
+ residual=None,
25
+ x1=None,
26
+ weight1=None,
27
+ bias1=None,
28
+ eps=1e-6,
29
+ dropout_p=0.0,
30
+ rowscale=None,
31
+ prenorm=False,
32
+ dropout_mask=None,
33
+ dropout_mask1=None,
34
+ upcast=False,
35
+ ):
36
+ dtype = x.dtype
37
+ if upcast:
38
+ x = x.float()
39
+ weight = weight.float()
40
+ bias = bias.float() if bias is not None else None
41
+ residual = residual.float() if residual is not None else residual
42
+ x1 = x1.float() if x1 is not None else None
43
+ weight1 = weight1.float() if weight1 is not None else None
44
+ bias1 = bias1.float() if bias1 is not None else None
45
+ if x1 is not None:
46
+ assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
47
+ if rowscale is not None:
48
+ x = x * rowscale[..., None]
49
+ if dropout_p > 0.0:
50
+ if dropout_mask is not None:
51
+ x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
52
+ else:
53
+ x = F.dropout(x, p=dropout_p)
54
+ if x1 is not None:
55
+ if dropout_mask1 is not None:
56
+ x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
57
+ else:
58
+ x1 = F.dropout(x1, p=dropout_p)
59
+ if x1 is not None:
60
+ x = x + x1
61
+ if residual is not None:
62
+ x = (x + residual).to(x.dtype)
63
+ out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
64
+ dtype
65
+ )
66
+ if weight1 is None:
67
+ return out if not prenorm else (out, x)
68
+ else:
69
+ out1 = F.layer_norm(
70
+ x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
71
+ ).to(dtype)
72
+ return (out, out1) if not prenorm else (out, out1, x)
73
+
74
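The reference makes the prenorm convention used throughout this file explicit: with `prenorm=True` it returns both the normalized output and the updated residual stream (`x + residual`), which the fused kernel mirrors through `residual_out`. A small pure-PyTorch sketch (CPU-runnable, names illustrative):

```python
# Pure PyTorch, CPU-runnable; names are illustrative.
import torch

hidden = 64
x = torch.randn(2, 8, hidden)
residual = torch.randn(2, 8, hidden)
weight, bias = torch.ones(hidden), torch.zeros(hidden)

out, new_residual = layer_norm_ref(x, weight, bias, residual=residual, prenorm=True)
print(out.shape)                                   # same shape as x
print(torch.allclose(new_residual, x + residual))  # prenorm returns the updated residual
```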
+
75
+ def rms_norm_ref(
76
+ x,
77
+ weight,
78
+ bias,
79
+ residual=None,
80
+ x1=None,
81
+ weight1=None,
82
+ bias1=None,
83
+ eps=1e-6,
84
+ dropout_p=0.0,
85
+ rowscale=None,
86
+ prenorm=False,
87
+ dropout_mask=None,
88
+ dropout_mask1=None,
89
+ upcast=False,
90
+ ):
91
+ dtype = x.dtype
92
+ if upcast:
93
+ x = x.float()
94
+ weight = weight.float()
95
+ bias = bias.float() if bias is not None else None
96
+ residual = residual.float() if residual is not None else residual
97
+ x1 = x1.float() if x1 is not None else None
98
+ weight1 = weight1.float() if weight1 is not None else None
99
+ bias1 = bias1.float() if bias1 is not None else None
100
+ if x1 is not None:
101
+ assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
102
+ if rowscale is not None:
103
+ x = x * rowscale[..., None]
104
+ if dropout_p > 0.0:
105
+ if dropout_mask is not None:
106
+ x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
107
+ else:
108
+ x = F.dropout(x, p=dropout_p)
109
+ if x1 is not None:
110
+ if dropout_mask1 is not None:
111
+ x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
112
+ else:
113
+ x1 = F.dropout(x1, p=dropout_p)
114
+ if x1 is not None:
115
+ x = x + x1
116
+ if residual is not None:
117
+ x = (x + residual).to(x.dtype)
118
+ rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
119
+ out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype)
120
+ if weight1 is None:
121
+ return out if not prenorm else (out, x)
122
+ else:
123
+ out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to(
124
+ dtype
125
+ )
126
+ return (out, out1) if not prenorm else (out, out1, x)
127
+
128
+ def config_prune(configs):
129
+
130
+ if torch.version.hip:
131
+ try:
132
+ # set warp size based on gcn architecture
133
+ gcn_arch_name = torch.cuda.get_device_properties(0).gcnArchName
134
+ if "gfx10" in gcn_arch_name or "gfx11" in gcn_arch_name:
135
+ # radeon
136
+ warp_size = 32
137
+ else:
138
+ # instinct
139
+ warp_size = 64
140
+ except AttributeError as e:
141
+ # fall back to crude method to set warp size
142
+ device_name = torch.cuda.get_device_properties(0).name
143
+ if 'instinct' in device_name.lower():
144
+ warp_size = 64
145
+ else:
146
+ warp_size = 32
147
+ warnings.warn(f"{e}, warp size set to {warp_size} based on device name: {device_name}", UserWarning)
148
+
149
+ else:
150
+ # cuda
151
+ warp_size = 32
152
+
153
+ max_block_sz = 1024
154
+ max_num_warps = max_block_sz // warp_size
155
+ pruned_configs = [config for config in configs if config.num_warps <= max_num_warps]
156
+ return pruned_configs
157
+
158
+ configs_autotune = [
159
+ triton.Config({}, num_warps=1),
160
+ triton.Config({}, num_warps=2),
161
+ triton.Config({}, num_warps=4),
162
+ triton.Config({}, num_warps=8),
163
+ triton.Config({}, num_warps=16),
164
+ triton.Config({}, num_warps=32),
165
+ ]
166
+
167
+ pruned_configs_autotune = config_prune(configs_autotune)
168
+
169
+ @triton.autotune(
170
+ configs = pruned_configs_autotune,
171
+ key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
172
+ )
173
+ # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
174
+ # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
175
+ @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
176
+ @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
177
+ @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
178
+ @triton.jit
179
+ def _layer_norm_fwd_1pass_kernel(
180
+ X, # pointer to the input
181
+ Y, # pointer to the output
182
+ W, # pointer to the weights
183
+ B, # pointer to the biases
184
+ RESIDUAL, # pointer to the residual
185
+ X1,
186
+ W1,
187
+ B1,
188
+ Y1,
189
+ RESIDUAL_OUT, # pointer to the residual
190
+ ROWSCALE,
191
+ SEEDS, # Dropout seeds for each row
192
+ DROPOUT_MASK,
193
+ Mean, # pointer to the mean
194
+ Rstd, # pointer to the 1/std
195
+ stride_x_row, # how much to increase the pointer when moving by 1 row
196
+ stride_y_row,
197
+ stride_res_row,
198
+ stride_res_out_row,
199
+ stride_x1_row,
200
+ stride_y1_row,
201
+ M, # number of rows in X
202
+ N, # number of columns in X
203
+ eps, # epsilon to avoid division by zero
204
+ dropout_p, # Dropout probability
205
+ IS_RMS_NORM: tl.constexpr,
206
+ BLOCK_N: tl.constexpr,
207
+ HAS_RESIDUAL: tl.constexpr,
208
+ STORE_RESIDUAL_OUT: tl.constexpr,
209
+ HAS_BIAS: tl.constexpr,
210
+ HAS_DROPOUT: tl.constexpr,
211
+ STORE_DROPOUT_MASK: tl.constexpr,
212
+ HAS_ROWSCALE: tl.constexpr,
213
+ HAS_X1: tl.constexpr,
214
+ HAS_W1: tl.constexpr,
215
+ HAS_B1: tl.constexpr,
216
+ ):
217
+ # Map the program id to the row of X and Y it should compute.
218
+ row = tl.program_id(0)
219
+ X += row * stride_x_row
220
+ Y += row * stride_y_row
221
+ if HAS_RESIDUAL:
222
+ RESIDUAL += row * stride_res_row
223
+ if STORE_RESIDUAL_OUT:
224
+ RESIDUAL_OUT += row * stride_res_out_row
225
+ if HAS_X1:
226
+ X1 += row * stride_x1_row
227
+ if HAS_W1:
228
+ Y1 += row * stride_y1_row
229
+ # Compute mean and variance
230
+ cols = tl.arange(0, BLOCK_N)
231
+ x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
232
+ if HAS_ROWSCALE:
233
+ rowscale = tl.load(ROWSCALE + row).to(tl.float32)
234
+ x *= rowscale
235
+ if HAS_DROPOUT:
236
+ # Compute dropout mask
237
+ # 7 rounds is good enough, and reduces register pressure
238
+ keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
239
+ x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
240
+ if STORE_DROPOUT_MASK:
241
+ tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
242
+ if HAS_X1:
243
+ x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
244
+ if HAS_ROWSCALE:
245
+ rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
246
+ x1 *= rowscale
247
+ if HAS_DROPOUT:
248
+ # Compute dropout mask
249
+ # 7 rounds is good enough, and reduces register pressure
250
+ keep_mask = (
251
+ tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
252
+ )
253
+ x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
254
+ if STORE_DROPOUT_MASK:
255
+ tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
256
+ x += x1
257
+ if HAS_RESIDUAL:
258
+ residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
259
+ x += residual
260
+ if STORE_RESIDUAL_OUT:
261
+ tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
262
+ if not IS_RMS_NORM:
263
+ mean = tl.sum(x, axis=0) / N
264
+ tl.store(Mean + row, mean)
265
+ xbar = tl.where(cols < N, x - mean, 0.0)
266
+ var = tl.sum(xbar * xbar, axis=0) / N
267
+ else:
268
+ xbar = tl.where(cols < N, x, 0.0)
269
+ var = tl.sum(xbar * xbar, axis=0) / N
270
+ rstd = 1 / tl.sqrt(var + eps)
271
+ tl.store(Rstd + row, rstd)
272
+ # Normalize and apply linear transformation
273
+ mask = cols < N
274
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
275
+ if HAS_BIAS:
276
+ b = tl.load(B + cols, mask=mask).to(tl.float32)
277
+ x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
278
+ y = x_hat * w + b if HAS_BIAS else x_hat * w
279
+ # Write output
280
+ tl.store(Y + cols, y, mask=mask)
281
+ if HAS_W1:
282
+ w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
283
+ if HAS_B1:
284
+ b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
285
+ y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
286
+ tl.store(Y1 + cols, y1, mask=mask)
287
+
288
+
289
+ def _layer_norm_fwd(
290
+ x,
291
+ weight,
292
+ bias,
293
+ eps,
294
+ residual=None,
295
+ x1=None,
296
+ weight1=None,
297
+ bias1=None,
298
+ dropout_p=0.0,
299
+ rowscale=None,
300
+ out_dtype=None,
301
+ residual_dtype=None,
302
+ is_rms_norm=False,
303
+ return_dropout_mask=False,
304
+ ):
305
+ if residual is not None:
306
+ residual_dtype = residual.dtype
307
+ M, N = x.shape
308
+ assert x.stride(-1) == 1
309
+ if residual is not None:
310
+ assert residual.stride(-1) == 1
311
+ assert residual.shape == (M, N)
312
+ assert weight.shape == (N,)
313
+ assert weight.stride(-1) == 1
314
+ if bias is not None:
315
+ assert bias.stride(-1) == 1
316
+ assert bias.shape == (N,)
317
+ if x1 is not None:
318
+ assert x1.shape == x.shape
319
+ assert rowscale is None
320
+ assert x1.stride(-1) == 1
321
+ if weight1 is not None:
322
+ assert weight1.shape == (N,)
323
+ assert weight1.stride(-1) == 1
324
+ if bias1 is not None:
325
+ assert bias1.shape == (N,)
326
+ assert bias1.stride(-1) == 1
327
+ if rowscale is not None:
328
+ assert rowscale.is_contiguous()
329
+ assert rowscale.shape == (M,)
330
+ # allocate output
331
+ y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
332
+ assert y.stride(-1) == 1
333
+ if weight1 is not None:
334
+ y1 = torch.empty_like(y)
335
+ assert y1.stride(-1) == 1
336
+ else:
337
+ y1 = None
338
+ if (
339
+ residual is not None
340
+ or (residual_dtype is not None and residual_dtype != x.dtype)
341
+ or dropout_p > 0.0
342
+ or rowscale is not None
343
+ or x1 is not None
344
+ ):
345
+ residual_out = torch.empty(
346
+ M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
347
+ )
348
+ assert residual_out.stride(-1) == 1
349
+ else:
350
+ residual_out = None
351
+ mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
352
+ rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
353
+ if dropout_p > 0.0:
354
+ seeds = torch.randint(
355
+ 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
356
+ )
357
+ else:
358
+ seeds = None
359
+ if return_dropout_mask and dropout_p > 0.0:
360
+ dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool)
361
+ else:
362
+ dropout_mask = None
363
+ # Less than 64KB per feature: enqueue fused kernel
364
+ MAX_FUSED_SIZE = 65536 // x.element_size()
365
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
366
+ if N > BLOCK_N:
367
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
368
+ with torch.cuda.device(x.device.index):
369
+ _layer_norm_fwd_1pass_kernel[(M,)](
370
+ x,
371
+ y,
372
+ weight,
373
+ bias,
374
+ residual,
375
+ x1,
376
+ weight1,
377
+ bias1,
378
+ y1,
379
+ residual_out,
380
+ rowscale,
381
+ seeds,
382
+ dropout_mask,
383
+ mean,
384
+ rstd,
385
+ x.stride(0),
386
+ y.stride(0),
387
+ residual.stride(0) if residual is not None else 0,
388
+ residual_out.stride(0) if residual_out is not None else 0,
389
+ x1.stride(0) if x1 is not None else 0,
390
+ y1.stride(0) if y1 is not None else 0,
391
+ M,
392
+ N,
393
+ eps,
394
+ dropout_p,
395
+ is_rms_norm,
396
+ BLOCK_N,
397
+ residual is not None,
398
+ residual_out is not None,
399
+ bias is not None,
400
+ dropout_p > 0.0,
401
+ dropout_mask is not None,
402
+ rowscale is not None,
403
+ )
404
+ # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
405
+ if dropout_mask is not None and x1 is not None:
406
+ dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
407
+ else:
408
+ dropout_mask1 = None
409
+ return (
410
+ y,
411
+ y1,
412
+ mean,
413
+ rstd,
414
+ residual_out if residual_out is not None else x,
415
+ seeds,
416
+ dropout_mask,
417
+ dropout_mask1,
418
+ )
419
+
420
+
421
+ @triton.autotune(
422
+ configs=pruned_configs_autotune,
423
+ key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
424
+ )
425
+ # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
426
+ # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
427
+ # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
428
+ @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
429
+ @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
430
+ @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
431
+ @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
432
+ @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
433
+ @triton.jit
434
+ def _layer_norm_bwd_kernel(
435
+ X, # pointer to the input
436
+ W, # pointer to the weights
437
+ B, # pointer to the biases
438
+ Y, # pointer to the output to be recomputed
439
+ DY, # pointer to the output gradient
440
+ DX, # pointer to the input gradient
441
+ DW, # pointer to the partial sum of weights gradient
442
+ DB, # pointer to the partial sum of biases gradient
443
+ DRESIDUAL,
444
+ W1,
445
+ DY1,
446
+ DX1,
447
+ DW1,
448
+ DB1,
449
+ DRESIDUAL_IN,
450
+ ROWSCALE,
451
+ SEEDS,
452
+ Mean, # pointer to the mean
453
+ Rstd, # pointer to the 1/std
454
+ stride_x_row, # how much to increase the pointer when moving by 1 row
455
+ stride_y_row,
456
+ stride_dy_row,
457
+ stride_dx_row,
458
+ stride_dres_row,
459
+ stride_dy1_row,
460
+ stride_dx1_row,
461
+ stride_dres_in_row,
462
+ M, # number of rows in X
463
+ N, # number of columns in X
464
+ eps, # epsilon to avoid division by zero
465
+ dropout_p,
466
+ rows_per_program,
467
+ IS_RMS_NORM: tl.constexpr,
468
+ BLOCK_N: tl.constexpr,
469
+ HAS_DRESIDUAL: tl.constexpr,
470
+ STORE_DRESIDUAL: tl.constexpr,
471
+ HAS_BIAS: tl.constexpr,
472
+ HAS_DROPOUT: tl.constexpr,
473
+ HAS_ROWSCALE: tl.constexpr,
474
+ HAS_DY1: tl.constexpr,
475
+ HAS_DX1: tl.constexpr,
476
+ HAS_B1: tl.constexpr,
477
+ RECOMPUTE_OUTPUT: tl.constexpr,
478
+ ):
479
+ # Map the program id to the elements of X, DX, and DY it should compute.
480
+ row_block_id = tl.program_id(0)
481
+ row_start = row_block_id * rows_per_program
482
+ # Do not early exit if row_start >= M, because we need to write DW and DB
483
+ cols = tl.arange(0, BLOCK_N)
484
+ mask = cols < N
485
+ X += row_start * stride_x_row
486
+ if HAS_DRESIDUAL:
487
+ DRESIDUAL += row_start * stride_dres_row
488
+ if STORE_DRESIDUAL:
489
+ DRESIDUAL_IN += row_start * stride_dres_in_row
490
+ DY += row_start * stride_dy_row
491
+ DX += row_start * stride_dx_row
492
+ if HAS_DY1:
493
+ DY1 += row_start * stride_dy1_row
494
+ if HAS_DX1:
495
+ DX1 += row_start * stride_dx1_row
496
+ if RECOMPUTE_OUTPUT:
497
+ Y += row_start * stride_y_row
498
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
499
+ if RECOMPUTE_OUTPUT and HAS_BIAS:
500
+ b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
501
+ if HAS_DY1:
502
+ w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
503
+ dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
504
+ if HAS_BIAS:
505
+ db = tl.zeros((BLOCK_N,), dtype=tl.float32)
506
+ if HAS_DY1:
507
+ dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
508
+ if HAS_B1:
509
+ db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
510
+ row_end = min((row_block_id + 1) * rows_per_program, M)
511
+ for row in range(row_start, row_end):
512
+ # Load data to SRAM
513
+ x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
514
+ dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
515
+ if HAS_DY1:
516
+ dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
517
+ if not IS_RMS_NORM:
518
+ mean = tl.load(Mean + row)
519
+ rstd = tl.load(Rstd + row)
520
+ # Compute dx
521
+ xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
522
+ xhat = tl.where(mask, xhat, 0.0)
523
+ if RECOMPUTE_OUTPUT:
524
+ y = xhat * w + b if HAS_BIAS else xhat * w
525
+ tl.store(Y + cols, y, mask=mask)
526
+ wdy = w * dy
527
+ dw += dy * xhat
528
+ if HAS_BIAS:
529
+ db += dy
530
+ if HAS_DY1:
531
+ wdy += w1 * dy1
532
+ dw1 += dy1 * xhat
533
+ if HAS_B1:
534
+ db1 += dy1
535
+ if not IS_RMS_NORM:
536
+ c1 = tl.sum(xhat * wdy, axis=0) / N
537
+ c2 = tl.sum(wdy, axis=0) / N
538
+ dx = (wdy - (xhat * c1 + c2)) * rstd
539
+ else:
540
+ c1 = tl.sum(xhat * wdy, axis=0) / N
541
+ dx = (wdy - xhat * c1) * rstd
542
+ if HAS_DRESIDUAL:
543
+ dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
544
+ dx += dres
545
+ # Write dx
546
+ if STORE_DRESIDUAL:
547
+ tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
548
+ if HAS_DX1:
549
+ if HAS_DROPOUT:
550
+ keep_mask = (
551
+ tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
552
+ )
553
+ dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
554
+ else:
555
+ dx1 = dx
556
+ tl.store(DX1 + cols, dx1, mask=mask)
557
+ if HAS_DROPOUT:
558
+ keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
559
+ dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
560
+ if HAS_ROWSCALE:
561
+ rowscale = tl.load(ROWSCALE + row).to(tl.float32)
562
+ dx *= rowscale
563
+ tl.store(DX + cols, dx, mask=mask)
564
+
565
+ X += stride_x_row
566
+ if HAS_DRESIDUAL:
567
+ DRESIDUAL += stride_dres_row
568
+ if STORE_DRESIDUAL:
569
+ DRESIDUAL_IN += stride_dres_in_row
570
+ if RECOMPUTE_OUTPUT:
571
+ Y += stride_y_row
572
+ DY += stride_dy_row
573
+ DX += stride_dx_row
574
+ if HAS_DY1:
575
+ DY1 += stride_dy1_row
576
+ if HAS_DX1:
577
+ DX1 += stride_dx1_row
578
+ tl.store(DW + row_block_id * N + cols, dw, mask=mask)
579
+ if HAS_BIAS:
580
+ tl.store(DB + row_block_id * N + cols, db, mask=mask)
581
+ if HAS_DY1:
582
+ tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
583
+ if HAS_B1:
584
+ tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)
585
+
586
+
587
+ def _layer_norm_bwd(
588
+ dy,
589
+ x,
590
+ weight,
591
+ bias,
592
+ eps,
593
+ mean,
594
+ rstd,
595
+ dresidual=None,
596
+ dy1=None,
597
+ weight1=None,
598
+ bias1=None,
599
+ seeds=None,
600
+ dropout_p=0.0,
601
+ rowscale=None,
602
+ has_residual=False,
603
+ has_x1=False,
604
+ is_rms_norm=False,
605
+ x_dtype=None,
606
+ recompute_output=False,
607
+ ):
608
+ M, N = x.shape
609
+ assert x.stride(-1) == 1
610
+ assert dy.stride(-1) == 1
611
+ assert dy.shape == (M, N)
612
+ if dresidual is not None:
613
+ assert dresidual.stride(-1) == 1
614
+ assert dresidual.shape == (M, N)
615
+ assert weight.shape == (N,)
616
+ assert weight.stride(-1) == 1
617
+ if bias is not None:
618
+ assert bias.stride(-1) == 1
619
+ assert bias.shape == (N,)
620
+ if dy1 is not None:
621
+ assert weight1 is not None
622
+ assert dy1.shape == dy.shape
623
+ assert dy1.stride(-1) == 1
624
+ if weight1 is not None:
625
+ assert weight1.shape == (N,)
626
+ assert weight1.stride(-1) == 1
627
+ if bias1 is not None:
628
+ assert bias1.shape == (N,)
629
+ assert bias1.stride(-1) == 1
630
+ if seeds is not None:
631
+ assert seeds.is_contiguous()
632
+ assert seeds.shape == (M if not has_x1 else M * 2,)
633
+ if rowscale is not None:
634
+ assert rowscale.is_contiguous()
635
+ assert rowscale.shape == (M,)
636
+ # allocate output
637
+ dx = (
638
+ torch.empty_like(x)
639
+ if x_dtype is None
640
+ else torch.empty(M, N, dtype=x_dtype, device=x.device)
641
+ )
642
+ dresidual_in = (
643
+ torch.empty_like(x)
644
+ if has_residual
645
+ and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
646
+ else None
647
+ )
648
+ dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
649
+ y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
650
+ if recompute_output:
651
+ assert weight1 is None, "recompute_output is not supported with parallel LayerNorm"
652
+
653
+ # Less than 64KB per feature: enqueue fused kernel
654
+ MAX_FUSED_SIZE = 65536 // x.element_size()
655
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
656
+ if N > BLOCK_N:
657
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
658
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
659
+ _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
660
+ _db = (
661
+ torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
662
+ if bias is not None
663
+ else None
664
+ )
665
+ _dw1 = torch.empty_like(_dw) if weight1 is not None else None
666
+ _db1 = torch.empty_like(_db) if bias1 is not None else None
667
+ rows_per_program = math.ceil(M / sm_count)
668
+ grid = (sm_count,)
669
+ with torch.cuda.device(x.device.index):
670
+ _layer_norm_bwd_kernel[grid](
671
+ x,
672
+ weight,
673
+ bias,
674
+ y,
675
+ dy,
676
+ dx,
677
+ _dw,
678
+ _db,
679
+ dresidual,
680
+ weight1,
681
+ dy1,
682
+ dx1,
683
+ _dw1,
684
+ _db1,
685
+ dresidual_in,
686
+ rowscale,
687
+ seeds,
688
+ mean,
689
+ rstd,
690
+ x.stride(0),
691
+ 0 if not recompute_output else y.stride(0),
692
+ dy.stride(0),
693
+ dx.stride(0),
694
+ dresidual.stride(0) if dresidual is not None else 0,
695
+ dy1.stride(0) if dy1 is not None else 0,
696
+ dx1.stride(0) if dx1 is not None else 0,
697
+ dresidual_in.stride(0) if dresidual_in is not None else 0,
698
+ M,
699
+ N,
700
+ eps,
701
+ dropout_p,
702
+ rows_per_program,
703
+ is_rms_norm,
704
+ BLOCK_N,
705
+ dresidual is not None,
706
+ dresidual_in is not None,
707
+ bias is not None,
708
+ dropout_p > 0.0,
709
+ )
710
+ dw = _dw.sum(0).to(weight.dtype)
711
+ db = _db.sum(0).to(bias.dtype) if bias is not None else None
712
+ dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
713
+ db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
714
+ # Don't need to compute dresidual_in separately in this case
715
+ if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
716
+ dresidual_in = dx
717
+ if has_x1 and dropout_p == 0.0:
718
+ dx1 = dx
719
+ return (
720
+ (dx, dw, db, dresidual_in, dx1, dw1, db1)
721
+ if not recompute_output
722
+ else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
723
+ )
724
+
725
+
726
+ class LayerNormFn(torch.autograd.Function):
727
+ @staticmethod
728
+ def forward(
729
+ ctx,
730
+ x,
731
+ weight,
732
+ bias,
733
+ residual=None,
734
+ x1=None,
735
+ weight1=None,
736
+ bias1=None,
737
+ eps=1e-6,
738
+ dropout_p=0.0,
739
+ rowscale=None,
740
+ prenorm=False,
741
+ residual_in_fp32=False,
742
+ is_rms_norm=False,
743
+ return_dropout_mask=False,
744
+ ):
745
+ x_shape_og = x.shape
746
+ # reshape input data into 2D tensor
747
+ x = x.reshape(-1, x.shape[-1])
748
+ if x.stride(-1) != 1:
749
+ x = x.contiguous()
750
+ if residual is not None:
751
+ assert residual.shape == x_shape_og
752
+ residual = residual.reshape(-1, residual.shape[-1])
753
+ if residual.stride(-1) != 1:
754
+ residual = residual.contiguous()
755
+ if x1 is not None:
756
+ assert x1.shape == x_shape_og
757
+ assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
758
+ x1 = x1.reshape(-1, x1.shape[-1])
759
+ if x1.stride(-1) != 1:
760
+ x1 = x1.contiguous()
761
+ weight = weight.contiguous()
762
+ if bias is not None:
763
+ bias = bias.contiguous()
764
+ if weight1 is not None:
765
+ weight1 = weight1.contiguous()
766
+ if bias1 is not None:
767
+ bias1 = bias1.contiguous()
768
+ if rowscale is not None:
769
+ rowscale = rowscale.reshape(-1).contiguous()
770
+ residual_dtype = (
771
+ residual.dtype
772
+ if residual is not None
773
+ else (torch.float32 if residual_in_fp32 else None)
774
+ )
775
+ y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
776
+ x,
777
+ weight,
778
+ bias,
779
+ eps,
780
+ residual,
781
+ x1,
782
+ weight1,
783
+ bias1,
784
+ dropout_p=dropout_p,
785
+ rowscale=rowscale,
786
+ residual_dtype=residual_dtype,
787
+ is_rms_norm=is_rms_norm,
788
+ return_dropout_mask=return_dropout_mask,
789
+ )
790
+ ctx.save_for_backward(
791
+ residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
792
+ )
793
+ ctx.x_shape_og = x_shape_og
794
+ ctx.eps = eps
795
+ ctx.dropout_p = dropout_p
796
+ ctx.is_rms_norm = is_rms_norm
797
+ ctx.has_residual = residual is not None
798
+ ctx.has_x1 = x1 is not None
799
+ ctx.prenorm = prenorm
800
+ ctx.x_dtype = x.dtype
801
+ y = y.reshape(x_shape_og)
802
+ y1 = y1.reshape(x_shape_og) if y1 is not None else None
803
+ residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
804
+ dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
805
+ dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
806
+ if not return_dropout_mask:
807
+ if weight1 is None:
808
+ return y if not prenorm else (y, residual_out)
809
+ else:
810
+ return (y, y1) if not prenorm else (y, y1, residual_out)
811
+ else:
812
+ if weight1 is None:
813
+ return (
814
+ (y, dropout_mask, dropout_mask1)
815
+ if not prenorm
816
+ else (y, residual_out, dropout_mask, dropout_mask1)
817
+ )
818
+ else:
819
+ return (
820
+ (y, y1, dropout_mask, dropout_mask1)
821
+ if not prenorm
822
+ else (y, y1, residual_out, dropout_mask, dropout_mask1)
823
+ )
824
+
825
+ @staticmethod
826
+ def backward(ctx, dy, *args):
827
+ x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
828
+ dy = dy.reshape(-1, dy.shape[-1])
829
+ if dy.stride(-1) != 1:
830
+ dy = dy.contiguous()
831
+ assert dy.shape == x.shape
832
+ if weight1 is not None:
833
+ dy1, args = args[0], args[1:]
834
+ dy1 = dy1.reshape(-1, dy1.shape[-1])
835
+ if dy1.stride(-1) != 1:
836
+ dy1 = dy1.contiguous()
837
+ assert dy1.shape == x.shape
838
+ else:
839
+ dy1 = None
840
+ if ctx.prenorm:
841
+ dresidual = args[0]
842
+ dresidual = dresidual.reshape(-1, dresidual.shape[-1])
843
+ if dresidual.stride(-1) != 1:
844
+ dresidual = dresidual.contiguous()
845
+ assert dresidual.shape == x.shape
846
+ else:
847
+ dresidual = None
848
+ dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
849
+ dy,
850
+ x,
851
+ weight,
852
+ bias,
853
+ ctx.eps,
854
+ mean,
855
+ rstd,
856
+ dresidual,
857
+ dy1,
858
+ weight1,
859
+ bias1,
860
+ seeds,
861
+ ctx.dropout_p,
862
+ rowscale,
863
+ ctx.has_residual,
864
+ ctx.has_x1,
865
+ ctx.is_rms_norm,
866
+ x_dtype=ctx.x_dtype,
867
+ )
868
+ return (
869
+ dx.reshape(ctx.x_shape_og),
870
+ dw,
871
+ db,
872
+ dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
873
+ dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
874
+ dw1,
875
+ db1,
876
+ None,
877
+ None,
878
+ None,
879
+ None,
880
+ None,
881
+ None,
882
+ None,
883
+ )
884
+
885
+
886
+ def layer_norm_fn(
887
+ x,
888
+ weight,
889
+ bias,
890
+ residual=None,
891
+ x1=None,
892
+ weight1=None,
893
+ bias1=None,
894
+ eps=1e-6,
895
+ dropout_p=0.0,
896
+ rowscale=None,
897
+ prenorm=False,
898
+ residual_in_fp32=False,
899
+ is_rms_norm=False,
900
+ return_dropout_mask=False,
901
+ ):
902
+ return LayerNormFn.apply(
903
+ x,
904
+ weight,
905
+ bias,
906
+ residual,
907
+ x1,
908
+ weight1,
909
+ bias1,
910
+ eps,
911
+ dropout_p,
912
+ rowscale,
913
+ prenorm,
914
+ residual_in_fp32,
915
+ is_rms_norm,
916
+ return_dropout_mask,
917
+ )
918
+
919
+
920
+ def rms_norm_fn(
921
+ x,
922
+ weight,
923
+ bias,
924
+ residual=None,
925
+ x1=None,
926
+ weight1=None,
927
+ bias1=None,
928
+ eps=1e-6,
929
+ dropout_p=0.0,
930
+ rowscale=None,
931
+ prenorm=False,
932
+ residual_in_fp32=False,
933
+ return_dropout_mask=False,
934
+ ):
935
+ return LayerNormFn.apply(
936
+ x,
937
+ weight,
938
+ bias,
939
+ residual,
940
+ x1,
941
+ weight1,
942
+ bias1,
943
+ eps,
944
+ dropout_p,
945
+ rowscale,
946
+ prenorm,
947
+ residual_in_fp32,
948
+ True,
949
+ return_dropout_mask,
950
+ )
951
+
952
+
953
+ class RMSNorm(torch.nn.Module):
954
+
955
+ def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
956
+ factory_kwargs = {"device": device, "dtype": dtype}
957
+ super().__init__()
958
+ self.eps = eps
959
+ if dropout_p > 0.0:
960
+ self.drop = torch.nn.Dropout(dropout_p)
961
+ else:
962
+ self.drop = None
963
+ self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
964
+ self.register_parameter("bias", None)
965
+ self.reset_parameters()
966
+
967
+ def reset_parameters(self):
968
+ torch.nn.init.ones_(self.weight)
969
+
970
+ def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
971
+ return rms_norm_fn(
972
+ x,
973
+ self.weight,
974
+ self.bias,
975
+ residual=residual,
976
+ eps=self.eps,
977
+ dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
978
+ prenorm=prenorm,
979
+ residual_in_fp32=residual_in_fp32,
980
+ )
981
+
982
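A sketch of how the `RMSNorm` module above is typically used as a prenorm layer, threading a float32 residual stream through half-precision activations. It assumes a CUDA device with Triton; the sizes are illustrative:

```python
# A sketch assuming a CUDA device with Triton; sizes are illustrative.
import torch

norm = RMSNorm(1024, eps=1e-5).cuda().half()
hidden = torch.randn(2, 128, 1024, device="cuda", dtype=torch.float16)
residual = None

# prenorm=True returns (normalized, updated_residual); residual_in_fp32 keeps
# the running residual in float32 across blocks for numerical stability.
normed, residual = norm(hidden, residual=residual, prenorm=True, residual_in_fp32=True)
print(normed.dtype, residual.dtype)  # torch.float16 torch.float32
```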
+
983
+ class LayerNormLinearFn(torch.autograd.Function):
984
+ @staticmethod
985
+ @custom_fwd
986
+ def forward(
987
+ ctx,
988
+ x,
989
+ norm_weight,
990
+ norm_bias,
991
+ linear_weight,
992
+ linear_bias,
993
+ residual=None,
994
+ eps=1e-6,
995
+ prenorm=False,
996
+ residual_in_fp32=False,
997
+ is_rms_norm=False,
998
+ ):
999
+ x_shape_og = x.shape
1000
+ # reshape input data into 2D tensor
1001
+ x = x.reshape(-1, x.shape[-1])
1002
+ if x.stride(-1) != 1:
1003
+ x = x.contiguous()
1004
+ if residual is not None:
1005
+ assert residual.shape == x_shape_og
1006
+ residual = residual.reshape(-1, residual.shape[-1])
1007
+ if residual.stride(-1) != 1:
1008
+ residual = residual.contiguous()
1009
+ norm_weight = norm_weight.contiguous()
1010
+ if norm_bias is not None:
1011
+ norm_bias = norm_bias.contiguous()
1012
+ residual_dtype = (
1013
+ residual.dtype
1014
+ if residual is not None
1015
+ else (torch.float32 if residual_in_fp32 else None)
1016
+ )
1017
+ y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
1018
+ x,
1019
+ norm_weight,
1020
+ norm_bias,
1021
+ eps,
1022
+ residual,
1023
+ out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
1024
+ residual_dtype=residual_dtype,
1025
+ is_rms_norm=is_rms_norm,
1026
+ )
1027
+ y = y.reshape(x_shape_og)
1028
+ dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
1029
+ linear_weight = linear_weight.to(dtype)
1030
+ linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
1031
+ out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
1032
+ # We don't store y, will be recomputed in the backward pass to save memory
1033
+ ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
1034
+ ctx.x_shape_og = x_shape_og
1035
+ ctx.eps = eps
1036
+ ctx.is_rms_norm = is_rms_norm
1037
+ ctx.has_residual = residual is not None
1038
+ ctx.prenorm = prenorm
1039
+ ctx.x_dtype = x.dtype
1040
+ ctx.linear_bias_is_none = linear_bias is None
1041
+ return out if not prenorm else (out, residual_out.reshape(x_shape_og))
1042
+
1043
+ @staticmethod
1044
+ @custom_bwd
1045
+ def backward(ctx, dout, *args):
1046
+ x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
1047
+ dout = dout.reshape(-1, dout.shape[-1])
1048
+ dy = F.linear(dout, linear_weight.t())
1049
+ dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
1050
+ if dy.stride(-1) != 1:
1051
+ dy = dy.contiguous()
1052
+ assert dy.shape == x.shape
1053
+ if ctx.prenorm:
1054
+ dresidual = args[0]
1055
+ dresidual = dresidual.reshape(-1, dresidual.shape[-1])
1056
+ if dresidual.stride(-1) != 1:
1057
+ dresidual = dresidual.contiguous()
1058
+ assert dresidual.shape == x.shape
1059
+ else:
1060
+ dresidual = None
1061
+ dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
1062
+ dy,
1063
+ x,
1064
+ norm_weight,
1065
+ norm_bias,
1066
+ ctx.eps,
1067
+ mean,
1068
+ rstd,
1069
+ dresidual=dresidual,
1070
+ has_residual=ctx.has_residual,
1071
+ is_rms_norm=ctx.is_rms_norm,
1072
+ x_dtype=ctx.x_dtype,
1073
+ recompute_output=True,
1074
+ )
1075
+ dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
1076
+ return (
1077
+ dx.reshape(ctx.x_shape_og),
1078
+ dnorm_weight,
1079
+ dnorm_bias,
1080
+ dlinear_weight,
1081
+ dlinear_bias,
1082
+ dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
1083
+ None,
1084
+ None,
1085
+ None,
1086
+ None,
1087
+ )
1088
+
1089
+
1090
+ def layer_norm_linear_fn(
1091
+ x,
1092
+ norm_weight,
1093
+ norm_bias,
1094
+ linear_weight,
1095
+ linear_bias,
1096
+ residual=None,
1097
+ eps=1e-6,
1098
+ prenorm=False,
1099
+ residual_in_fp32=False,
1100
+ is_rms_norm=False,
1101
+ ):
1102
+ return LayerNormLinearFn.apply(
1103
+ x,
1104
+ norm_weight,
1105
+ norm_bias,
1106
+ linear_weight,
1107
+ linear_bias,
1108
+ residual,
1109
+ eps,
1110
+ prenorm,
1111
+ residual_in_fp32,
1112
+ is_rms_norm,
1113
+ )
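For reference, the fused `layer_norm_linear_fn` can be sanity-checked against an unfused norm-then-linear composition. A sketch assuming a CUDA device with Triton; shapes are illustrative:

```python
# A sketch assuming a CUDA device with Triton; shapes are illustrative.
import torch
import torch.nn.functional as F

d_in, d_out = 512, 1024
x = torch.randn(4, 32, d_in, device="cuda")
norm_w = torch.randn(d_in, device="cuda")
norm_b = torch.randn(d_in, device="cuda")
lin_w = torch.randn(d_out, d_in, device="cuda")
lin_b = torch.randn(d_out, device="cuda")

fused = layer_norm_linear_fn(x, norm_w, norm_b, lin_w, lin_b, eps=1e-6)
ref = F.linear(F.layer_norm(x, (d_in,), norm_w, norm_b, eps=1e-6), lin_w, lin_b)
print((fused - ref).abs().max())  # expect only a small fp32 rounding difference
```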
build/torch210-cxx11-cu126-x86_64-linux/ops/triton/layernorm_gated.py ADDED
@@ -0,0 +1,437 @@
1
+ # Copyright (c) 2024, Tri Dao.
2
+ # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
3
+ # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
4
+ # This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
5
+ # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
6
+
7
+ import math
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+
12
+ import triton
13
+ import triton.language as tl
14
+
15
+ from einops import rearrange
16
+
17
+
18
+ def rms_norm_ref(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, upcast=True):
19
+ dtype = x.dtype
20
+ N = x.shape[-1]
21
+ weight = weight.float()
22
+ bias = bias.float() if bias is not None else None
23
+ if upcast:
24
+ x = x.float()
25
+ z = z.float() if z is not None else z
26
+ if z is not None and not norm_before_gate:
27
+ x = x * F.silu(z)
28
+ if group_size is None:
29
+ rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
30
+ out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
31
+ else:
32
+ x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
33
+ rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
34
+ out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
35
+ if bias is not None:
36
+ out = out + bias
37
+ if z is not None and norm_before_gate:
38
+ out *= F.silu(z)
39
+ return out.to(dtype)
40
+
41
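The reference above makes the two gating orders explicit: with `norm_before_gate=True` it computes `rmsnorm(x) * silu(z)`, with `False` it computes `rmsnorm(x * silu(z))`. A small pure-PyTorch sketch (CPU-runnable) that cross-checks the first path by hand:

```python
# Pure PyTorch, CPU-runnable; names are illustrative.
import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 64)
z = torch.randn(2, 8, 64)
w = torch.ones(64)

gate_after = rms_norm_ref(x, w, None, z=z, norm_before_gate=True)    # rmsnorm(x) * silu(z)
gate_before = rms_norm_ref(x, w, None, z=z, norm_before_gate=False)  # rmsnorm(x * silu(z))

# Manual check of the norm_before_gate=True path.
rstd = 1 / torch.sqrt(x.square().mean(-1, keepdim=True) + 1e-6)
manual = (x * rstd * w) * F.silu(z)
print(torch.allclose(gate_after, manual, atol=1e-6))
```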
+
42
+ @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
43
+ @triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
44
+ @triton.jit
45
+ def _layer_norm_fwd_1pass_kernel(
46
+ X, # pointer to the input
47
+ Y, # pointer to the output
48
+ W, # pointer to the weights
49
+ B, # pointer to the biases
50
+ Z, # pointer to the other branch
51
+ Mean, # pointer to the mean
52
+ Rstd, # pointer to the 1/std
53
+ stride_x_row, # how much to increase the pointer when moving by 1 row
54
+ stride_y_row,
55
+ stride_z_row,
56
+ M, # number of rows in X
57
+ N, # number of columns in X
58
+ eps, # epsilon to avoid division by zero
59
+ BLOCK_N: tl.constexpr,
60
+ HAS_BIAS: tl.constexpr,
61
+ HAS_Z: tl.constexpr,
62
+ NORM_BEFORE_GATE: tl.constexpr,
63
+ IS_RMS_NORM: tl.constexpr,
64
+ ):
65
+ # Map the program id to the row of X and Y it should compute.
66
+ row = tl.program_id(0)
67
+ group = tl.program_id(1)
68
+ X += row * stride_x_row + group * N
69
+ Y += row * stride_y_row + group * N
70
+ if HAS_Z:
71
+ Z += row * stride_z_row + group * N
72
+ if not IS_RMS_NORM:
73
+ Mean += group * M
74
+ Rstd += group * M
75
+ W += group * N
76
+ if HAS_BIAS:
77
+ B += group * N
78
+ # Compute mean and variance
79
+ cols = tl.arange(0, BLOCK_N)
80
+ x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
81
+ if HAS_Z and not NORM_BEFORE_GATE:
82
+ z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
83
+ x *= z * tl.sigmoid(z)
84
+ if not IS_RMS_NORM:
85
+ mean = tl.sum(x, axis=0) / N
86
+ tl.store(Mean + row, mean)
87
+ xbar = tl.where(cols < N, x - mean, 0.)
88
+ var = tl.sum(xbar * xbar, axis=0) / N
89
+ else:
90
+ xbar = tl.where(cols < N, x, 0.)
91
+ var = tl.sum(xbar * xbar, axis=0) / N
92
+ rstd = 1 / tl.sqrt(var + eps)
93
+ tl.store(Rstd + row, rstd)
94
+ # Normalize and apply linear transformation
95
+ mask = cols < N
96
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
97
+ if HAS_BIAS:
98
+ b = tl.load(B + cols, mask=mask).to(tl.float32)
99
+ x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
100
+ y = x_hat * w + b if HAS_BIAS else x_hat * w
101
+ if HAS_Z and NORM_BEFORE_GATE:
102
+ z = tl.load(Z + cols, mask=mask).to(tl.float32)
103
+ y *= z * tl.sigmoid(z)
104
+ # Write output
105
+ tl.store(Y + cols, y, mask=mask)
106
+
107
+
108
+ def _layer_norm_fwd(x, weight, bias, eps, z=None, out=None, group_size=None, norm_before_gate=True, is_rms_norm=False):
109
+ M, N = x.shape
110
+ if group_size is None:
111
+ group_size = N
112
+ assert N % group_size == 0
113
+ ngroups = N // group_size
114
+ assert x.stride(-1) == 1
115
+ if z is not None:
116
+ assert z.stride(-1) == 1
117
+ assert z.shape == (M, N)
118
+ assert weight.shape == (N,)
119
+ assert weight.stride(-1) == 1
120
+ if bias is not None:
121
+ assert bias.stride(-1) == 1
122
+ assert bias.shape == (N,)
123
+ # allocate output
124
+ if out is not None:
125
+ assert out.shape == x.shape
126
+ else:
127
+ out = torch.empty_like(x)
128
+ assert out.stride(-1) == 1
129
+ mean = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) if not is_rms_norm else None
130
+ rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
131
+ # Less than 64KB per feature: enqueue fused kernel
132
+ MAX_FUSED_SIZE = 65536 // x.element_size()
133
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
134
+ if group_size > BLOCK_N:
135
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
136
+ # heuristics for number of warps
137
+ num_warps = min(max(BLOCK_N // 256, 1), 8)
138
+ grid = (M, ngroups)
139
+ with torch.cuda.device(x.device.index):
140
+ _layer_norm_fwd_1pass_kernel[grid](x, out, weight, bias, z, mean, rstd,
141
+ x.stride(0), out.stride(0), z.stride(0) if z is not None else 0,
142
+ M, group_size, eps,
143
+ BLOCK_N=BLOCK_N,
144
+ NORM_BEFORE_GATE=norm_before_gate,
145
+ IS_RMS_NORM=is_rms_norm,
146
+ num_warps=num_warps)
147
+ return out, mean, rstd
148
+
149
+
150
+
151
+ @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
152
+ @triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
153
+ @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
154
+ @triton.jit
155
+ def _layer_norm_bwd_kernel(
156
+ X, # pointer to the input
157
+ W, # pointer to the weights
158
+ B, # pointer to the biases
159
+ Z, # pointer to the other branch
160
+ Y, # pointer to the output to be recomputed
161
+ DY, # pointer to the output gradient
162
+ DX, # pointer to the input gradient
163
+ DW, # pointer to the partial sum of weights gradient
164
+ DB, # pointer to the partial sum of biases gradient
165
+ DZ, # pointer to the other branch
166
+ Mean, # pointer to the mean
167
+ Rstd, # pointer to the 1/std
168
+ stride_x_row, # how much to increase the pointer when moving by 1 row
169
+ stride_z_row,
170
+ stride_y_row,
171
+ stride_dy_row,
172
+ stride_dx_row,
173
+ stride_dz_row,
174
+ stride_dw_row,
175
+ stride_db_row,
176
+ M, # number of rows in X
177
+ N, # number of columns in X
178
+ eps, # epsilon to avoid division by zero
179
+ rows_per_program,
180
+ NORM_BEFORE_GATE: tl.constexpr,
181
+ IS_RMS_NORM: tl.constexpr,
182
+ HAS_BIAS: tl.constexpr,
183
+ HAS_Z: tl.constexpr,
184
+ RECOMPUTE_OUTPUT: tl.constexpr,
185
+ BLOCK_N: tl.constexpr,
186
+ ):
187
+ # Map the program id to the elements of X, DX, and DY it should compute.
188
+ row_block_id = tl.program_id(0)
189
+ group = tl.program_id(1)
190
+ row_start = row_block_id * rows_per_program
191
+ cols = tl.arange(0, BLOCK_N)
192
+ mask = cols < N
193
+ X += row_start * stride_x_row + group * N
194
+ if HAS_Z:
195
+ Z += row_start * stride_z_row + group * N
196
+ DZ += row_start * stride_dz_row + group * N
197
+ DY += row_start * stride_dy_row + group * N
198
+ DX += row_start * stride_dx_row + group * N
199
+ if RECOMPUTE_OUTPUT:
200
+ Y += row_start * stride_y_row + group * N
201
+ if not IS_RMS_NORM:
202
+ Mean += group * M
203
+ Rstd += group * M
204
+ W += group * N
205
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
206
+ if (RECOMPUTE_OUTPUT or HAS_Z) and HAS_BIAS:
207
+ B += group * N
208
+ b = tl.load(B + cols, mask=mask, other=0.).to(tl.float32)
209
+ dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
210
+ if HAS_BIAS:
211
+ db = tl.zeros((BLOCK_N,), dtype=tl.float32)
212
+ row_end = min((row_block_id + 1) * rows_per_program, M)
213
+ for row in range(row_start, row_end):
214
+ # Load data to SRAM
215
+ x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
216
+ dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
217
+ if not IS_RMS_NORM:
218
+ mean = tl.load(Mean + row)
219
+ if HAS_Z and not NORM_BEFORE_GATE:
220
+ z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32)
221
+ x_og = x
222
+ x = x_og * z * tl.sigmoid(z)
223
+ rstd = tl.load(Rstd + row)
224
+ # Compute dx
225
+ xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
226
+ xhat = tl.where(mask, xhat, 0.)
227
+ if HAS_Z and NORM_BEFORE_GATE:
228
+ z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32)
229
+ z_sigmoid = tl.sigmoid(z)
230
+ y = xhat * w + b if HAS_BIAS else xhat * w
231
+ if RECOMPUTE_OUTPUT:
232
+ tl.store(Y + cols, y * z * z_sigmoid, mask=mask)
233
+ dz = dy * y * z_sigmoid * (1 + z * (1 - z_sigmoid))
234
+ tl.store(DZ + cols, dz, mask=mask)
235
+ dy *= z * z_sigmoid
236
+ else:
237
+ if RECOMPUTE_OUTPUT:
238
+ y = xhat * w + b if HAS_BIAS else xhat * w
239
+ tl.store(Y + cols, y, mask=mask)
240
+ wdy = w * dy
241
+ c1 = tl.sum(xhat * wdy, axis=0) / N
242
+ if not IS_RMS_NORM:
243
+ c2 = tl.sum(wdy, axis=0) / N
244
+ dx = (wdy - (xhat * c1 + c2)) * rstd
245
+ else:
246
+ dx = (wdy - xhat * c1) * rstd
247
+ dw += dy * xhat
248
+ if HAS_BIAS:
249
+ db += dy
250
+ if HAS_Z and not NORM_BEFORE_GATE:
251
+ z_sigmoid = tl.sigmoid(z)
252
+ dz = dx * x_og * z_sigmoid * (1 + z * (1 - z_sigmoid))
253
+ tl.store(DZ + cols, dz, mask=mask)
254
+ dx *= z * z_sigmoid
255
+ # Write dx
256
+ tl.store(DX + cols, dx, mask=mask)
257
+
258
+ X += stride_x_row
259
+ if HAS_Z:
260
+ Z += stride_z_row
261
+ DZ += stride_dz_row
262
+ if RECOMPUTE_OUTPUT:
263
+ Y += stride_y_row
264
+ DY += stride_dy_row
265
+ DX += stride_dx_row
266
+ tl.store(DW + row_block_id * stride_dw_row + group * N + cols, dw, mask=mask)
267
+ if HAS_BIAS:
268
+ tl.store(DB + row_block_id * stride_db_row + group * N + cols, db, mask=mask)
269
+
270
+
271
+ def _layer_norm_bwd(dy, x, weight, bias, eps, mean, rstd, z=None, group_size=None,
272
+ norm_before_gate=True, is_rms_norm=False, recompute_output=False, dz=None, out=None):
273
+ M, N = x.shape
274
+ if group_size is None:
275
+ group_size = N
276
+ assert N % group_size == 0
277
+ ngroups = N // group_size
278
+ assert x.stride(-1) == 1
279
+ assert dy.stride(-1) == 1
280
+ assert dy.shape == (M, N)
281
+ if z is not None:
282
+ assert z.stride(-1) == 1
283
+ assert z.shape == (M, N)
284
+ assert weight.shape == (N,)
285
+ assert weight.stride(-1) == 1
286
+ if bias is not None:
287
+ assert bias.stride(-1) == 1
288
+ assert bias.shape == (N,)
289
+ # allocate output
290
+ dx = torch.empty_like(x)
291
+ if dz is not None:
292
+ assert z is not None
293
+ assert dz.shape == z.shape
294
+ assert dz.stride(-1) == 1
295
+ else:
296
+ dz = torch.empty_like(z) if z is not None else None
297
+ if recompute_output:
298
+ if out is None:
299
+ out = torch.empty_like(x)
300
+ assert out.shape == x.shape
301
+
302
+ # Less than 64KB per feature: enqueue fused kernel
303
+ MAX_FUSED_SIZE = 65536 // x.element_size()
304
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
305
+ if group_size > BLOCK_N:
306
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
307
+ # heuristics for number of warps
308
+ num_warps = min(max(BLOCK_N // 256, 1), 8)
309
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
310
+ # If group size is small (e.g., 64), we're only using 1 warp. So having just 108 programs
311
+ # would limit the occupancy.
312
+ nrow_groups = math.ceil(sm_count * math.ceil(4 / num_warps) / ngroups)
313
+ _dw = torch.empty((nrow_groups, N), dtype=torch.float32, device=weight.device)
314
+ _db = torch.empty((nrow_groups, N), dtype=torch.float32, device=bias.device) if bias is not None else None
315
+ rows_per_program = math.ceil(M / nrow_groups)
316
+ grid = (nrow_groups, ngroups)
317
+ with torch.cuda.device(x.device.index):
318
+ _layer_norm_bwd_kernel[grid](x, weight, bias, z, out if recompute_output else None,
319
+ dy, dx, _dw, _db, dz, mean, rstd,
320
+ x.stride(0),
321
+ z.stride(0) if z is not None else 0,
322
+ 0 if not recompute_output else out.stride(0),
323
+ dy.stride(0), dx.stride(0),
324
+ dz.stride(0) if dz is not None else 0,
325
+ _dw.stride(0),
326
+ _db.stride(0) if _db is not None else 0,
327
+ M, group_size, eps,
328
+ rows_per_program,
329
+ BLOCK_N=BLOCK_N,
330
+ NORM_BEFORE_GATE=norm_before_gate,
331
+ IS_RMS_NORM=is_rms_norm,
332
+ num_warps=num_warps)
333
+ dw = _dw.sum(0).to(weight.dtype)
334
+ db = _db.sum(0).to(bias.dtype) if bias is not None else None
335
+ return (dx, dw, db, dz) if not recompute_output else (dx, dw, db, dz, out)
336
+
337
+
338
+ class LayerNormFn(torch.autograd.Function):
339
+
340
+ @staticmethod
341
+ def forward(ctx, x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True,
342
+ is_rms_norm=False):
343
+ """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
344
+ """
345
+
346
+ x_shape_og = x.shape
347
+ # reshape input data into 2D tensor
348
+ x = x.reshape(-1, x.shape[-1])
349
+ if x.stride(-1) != 1:
350
+ x = x.contiguous()
351
+ if z is not None:
352
+ assert z.shape == x_shape_og
353
+ z = z.reshape(-1, z.shape[-1])
354
+ if z.stride(-1) != 1:
355
+ z = z.contiguous()
356
+ weight = weight.contiguous()
357
+ if bias is not None:
358
+ bias = bias.contiguous()
359
+ y, mean, rstd = _layer_norm_fwd(x, weight, bias, eps, z=z, group_size=group_size, norm_before_gate=norm_before_gate, is_rms_norm=is_rms_norm)
360
+ ctx.save_for_backward(x, weight, bias, mean, rstd, z)
361
+ ctx.x_shape_og = x_shape_og
362
+ ctx.eps = eps
363
+ ctx.group_size = group_size
364
+ ctx.norm_before_gate = norm_before_gate
365
+ ctx.is_rms_norm = is_rms_norm
366
+ return y.reshape(x_shape_og)
367
+
368
+ @staticmethod
369
+ def backward(ctx, dy):
370
+ x, weight, bias, mean, rstd, z = ctx.saved_tensors
371
+ dy = dy.reshape(-1, dy.shape[-1])
372
+ if dy.stride(-1) != 1:
373
+ dy = dy.contiguous()
374
+ assert dy.shape == x.shape
375
+ dx, dw, db, dz = _layer_norm_bwd(dy, x, weight, bias, ctx.eps, mean, rstd, z, ctx.group_size,
376
+ ctx.norm_before_gate, ctx.is_rms_norm)
377
+ return dx.reshape(ctx.x_shape_og), dw, db, dz.reshape(ctx.x_shape_og) if dz is not None else None, None, None, None, None
378
+
379
+
380
+ def layernorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, is_rms_norm=False):
381
+ return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm)
382
+
383
+
384
+ def rmsnorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True):
385
+ return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, True)
386
+
387
+
388
+ class LayerNorm(torch.nn.Module):
389
+
390
+ def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, device=None, dtype=None):
391
+ """If group_size is not None, we do GroupNorm with each group having group_size elements.
392
+ group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
393
+ """
394
+
395
+ factory_kwargs = {"device": device, "dtype": dtype}
396
+ super().__init__()
397
+ self.eps = eps
398
+ self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
399
+ self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
400
+ self.group_size = group_size
401
+ self.norm_before_gate = norm_before_gate
402
+ self.reset_parameters()
403
+
404
+ def reset_parameters(self):
405
+ torch.nn.init.ones_(self.weight)
406
+ torch.nn.init.zeros_(self.bias)
407
+
408
+ def forward(self, x, z=None):
409
+ """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
410
+ """
411
+ return layernorm_fn(x, self.weight, self.bias, z=z, group_size=self.group_size, eps=self.eps,
412
+ norm_before_gate=self.norm_before_gate)
413
+
414
+
415
+ class RMSNorm(torch.nn.Module):
416
+
417
+ def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, device=None, dtype=None):
418
+ """If group_size is not None, we do GroupNorm with each group having group_size elements.
419
+ group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
420
+ """
421
+ factory_kwargs = {"device": device, "dtype": dtype}
422
+ super().__init__()
423
+ self.eps = eps
424
+ self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
425
+ self.register_parameter("bias", None)
426
+ self.group_size = group_size
427
+ self.norm_before_gate = norm_before_gate
428
+ self.reset_parameters()
429
+
430
+ def reset_parameters(self):
431
+ torch.nn.init.ones_(self.weight)
432
+
433
+ def forward(self, x, z=None):
434
+ """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
435
+ """
436
+ return rmsnorm_fn(x, self.weight, self.bias, z=z, eps=self.eps, group_size=self.group_size,
437
+ norm_before_gate=self.norm_before_gate)
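A minimal usage sketch of the gated RMSNorm defined above (an editor-added illustration, not part of the committed diff): with norm_before_gate=True and a gate z, the module computes rmsnorm(x) * silu(z) in one fused Triton kernel. The import path, shapes, dtype, and the CUDA device are assumptions made only for this example.

# Hypothetical usage sketch; assumes this file is importable as `layernorm_gated`
# and that a CUDA device is available (the Triton kernels require one).
import torch
import torch.nn.functional as F
from layernorm_gated import RMSNorm  # assumed import path

x = torch.randn(2, 64, 256, device="cuda", dtype=torch.float16)
z = torch.randn_like(x)  # gate, same shape as x
norm = RMSNorm(256, eps=1e-5, norm_before_gate=True, device="cuda", dtype=torch.float16)
y = norm(x, z)  # fused kernel: rmsnorm(x) * silu(z)

# Plain-PyTorch reference in fp32 for comparison.
ref = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + 1e-5)
ref = ref * norm.weight.float() * F.silu(z.float())
print(torch.allclose(y.float(), ref, atol=1e-2, rtol=1e-2))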
build/torch210-cxx11-cu126-x86_64-linux/ops/triton/selective_state_update.py ADDED
@@ -0,0 +1,285 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ """We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this
4
+ """
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ import triton
11
+ import triton.language as tl
12
+
13
+ from einops import rearrange, repeat
14
+
15
+ from .softplus import softplus
16
+
17
+
18
+ @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
19
+ @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
20
+ @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
21
+ @triton.heuristics({"HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"] is not None})
22
+ @triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
23
+ @triton.jit
24
+ def _selective_scan_update_kernel(
25
+ # Pointers to matrices
26
+ state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr, state_batch_indices_ptr,
27
+ # Matrix dimensions
28
+ batch, nheads, dim, dstate, nheads_ngroups_ratio,
29
+ # Strides
30
+ stride_state_batch, stride_state_head, stride_state_dim, stride_state_dstate,
31
+ stride_x_batch, stride_x_head, stride_x_dim,
32
+ stride_dt_batch, stride_dt_head, stride_dt_dim,
33
+ stride_dt_bias_head, stride_dt_bias_dim,
34
+ stride_A_head, stride_A_dim, stride_A_dstate,
35
+ stride_B_batch, stride_B_group, stride_B_dstate,
36
+ stride_C_batch, stride_C_group, stride_C_dstate,
37
+ stride_D_head, stride_D_dim,
38
+ stride_z_batch, stride_z_head, stride_z_dim,
39
+ stride_out_batch, stride_out_head, stride_out_dim,
40
+ # Meta-parameters
41
+ DT_SOFTPLUS: tl.constexpr,
42
+ TIE_HDIM: tl.constexpr,
43
+ BLOCK_SIZE_M: tl.constexpr,
44
+ HAS_DT_BIAS: tl.constexpr,
45
+ HAS_D: tl.constexpr,
46
+ HAS_Z: tl.constexpr,
47
+ HAS_STATE_BATCH_INDICES: tl.constexpr,
48
+ BLOCK_SIZE_DSTATE: tl.constexpr,
49
+ ):
50
+ pid_m = tl.program_id(axis=0)
51
+ pid_b = tl.program_id(axis=1)
52
+ pid_h = tl.program_id(axis=2)
53
+
54
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
55
+ out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
56
+ out_ptrs = out_ptr + offs_m * stride_out_dim
57
+
58
+ if HAS_STATE_BATCH_INDICES:
59
+ state_batch_indices_ptr += pid_b
60
+ state_batch_idx = tl.load(state_batch_indices_ptr)
61
+ # Skip padding tokens
62
+ if state_batch_idx < 0:
63
+ tl.store(out_ptrs, 0.0, mask=offs_m < dim)
64
+ return
65
+ state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
66
+ else:
67
+ state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
68
+
69
+ x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
70
+ dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
71
+ if HAS_DT_BIAS:
72
+ dt_bias_ptr += pid_h * stride_dt_bias_head
73
+ A_ptr += pid_h * stride_A_head
74
+ B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
75
+ C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
76
+ if HAS_Z:
77
+ z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
78
+
79
+ offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
80
+ state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate)
81
+ x_ptrs = x_ptr + offs_m * stride_x_dim
82
+ dt_ptrs = dt_ptr + offs_m * stride_dt_dim
83
+ if HAS_DT_BIAS:
84
+ dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
85
+ if HAS_D:
86
+ D_ptr += pid_h * stride_D_head
87
+ A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)
88
+ B_ptrs = B_ptr + offs_n * stride_B_dstate
89
+ C_ptrs = C_ptr + offs_n * stride_C_dstate
90
+ if HAS_D:
91
+ D_ptrs = D_ptr + offs_m * stride_D_dim
92
+ if HAS_Z:
93
+ z_ptrs = z_ptr + offs_m * stride_z_dim
94
+
95
+ state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)
96
+ x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
97
+ if not TIE_HDIM:
98
+ dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
99
+ if HAS_DT_BIAS:
100
+ dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
101
+ if DT_SOFTPLUS:
102
+ dt = tl.where(dt <= 20.0, softplus(dt), dt)
103
+ A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
104
+ dA = tl.exp(A * dt[:, None])
105
+ else:
106
+ dt = tl.load(dt_ptr).to(tl.float32)
107
+ if HAS_DT_BIAS:
108
+ dt += tl.load(dt_bias_ptr).to(tl.float32)
109
+ if DT_SOFTPLUS:
110
+ dt = tl.where(dt <= 20.0, softplus(dt), dt)
111
+ A = tl.load(A_ptr).to(tl.float32)
112
+ dA = tl.exp(A * dt) # scalar, not a matrix
113
+
114
+ B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
115
+ C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
116
+ if HAS_D:
117
+ D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
118
+ if HAS_Z:
119
+ z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
120
+
121
+ if not TIE_HDIM:
122
+ dB = B[None, :] * dt[:, None]
123
+ else:
124
+ dB = B * dt # vector of size (dstate,)
125
+ state = state * dA + dB * x[:, None]
126
+ tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
127
+ out = tl.sum(state * C[None, :], axis=1)
128
+ if HAS_D:
129
+ out += x * D
130
+ if HAS_Z:
131
+ out *= z * tl.sigmoid(z)
132
+ tl.store(out_ptrs, out, mask=offs_m < dim)
133
+
134
+
135
+ def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False,
136
+ state_batch_indices=None):
137
+ """
138
+ Argument:
139
+ state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
140
+ x: (batch, dim) or (batch, nheads, dim)
141
+ dt: (batch, dim) or (batch, nheads, dim)
142
+ A: (dim, dstate) or (nheads, dim, dstate)
143
+ B: (batch, dstate) or (batch, ngroups, dstate)
144
+ C: (batch, dstate) or (batch, ngroups, dstate)
145
+ D: (dim,) or (nheads, dim)
146
+ z: (batch, dim) or (batch, nheads, dim)
147
+ dt_bias: (dim,) or (nheads, dim)
148
+ Return:
149
+ out: (batch, dim) or (batch, nheads, dim)
150
+ """
151
+ has_heads = state.dim() > 3
152
+ if state.dim() == 3:
153
+ state = state.unsqueeze(1)
154
+ if x.dim() == 2:
155
+ x = x.unsqueeze(1)
156
+ if dt.dim() == 2:
157
+ dt = dt.unsqueeze(1)
158
+ if A.dim() == 2:
159
+ A = A.unsqueeze(0)
160
+ if B.dim() == 2:
161
+ B = B.unsqueeze(1)
162
+ if C.dim() == 2:
163
+ C = C.unsqueeze(1)
164
+ if D is not None and D.dim() == 1:
165
+ D = D.unsqueeze(0)
166
+ if z is not None and z.dim() == 2:
167
+ z = z.unsqueeze(1)
168
+ if dt_bias is not None and dt_bias.dim() == 1:
169
+ dt_bias = dt_bias.unsqueeze(0)
170
+ _, nheads, dim, dstate = state.shape
171
+ batch = x.shape[0]
172
+ if x.shape != (batch, nheads, dim):
173
+ print(f"{state.shape} {x.shape} {batch} {nheads} {dim}")
174
+ assert x.shape == (batch, nheads, dim)
175
+ assert dt.shape == x.shape
176
+ assert A.shape == (nheads, dim, dstate)
177
+ ngroups = B.shape[1]
178
+ assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
179
+ assert B.shape == (batch, ngroups, dstate)
180
+ assert C.shape == B.shape
181
+ if D is not None:
182
+ assert D.shape == (nheads, dim)
183
+ if z is not None:
184
+ assert z.shape == x.shape
185
+ if dt_bias is not None:
186
+ assert dt_bias.shape == (nheads, dim)
187
+ if state_batch_indices is not None:
188
+ assert state_batch_indices.shape == (batch,)
189
+ out = torch.empty_like(x)
190
+ grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
191
+ z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0))
192
+ # We don't want autotune since it will overwrite the state
193
+ # We instead tune by hand.
194
+ BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16
195
+ else ((16, 4) if dstate <= 32 else
196
+ ((8, 4) if dstate <= 64 else
197
+ ((4, 4) if dstate <= 128 else
198
+ ((4, 8))))))
199
+ tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0 and dt_bias.stride(-1) == 0
200
+ with torch.cuda.device(x.device.index):
201
+ _selective_scan_update_kernel[grid](
202
+ state, x, dt, dt_bias, A, B, C, D, z, out, state_batch_indices,
203
+ batch, nheads, dim, dstate, nheads // ngroups,
204
+ state.stride(0), state.stride(1), state.stride(2), state.stride(3),
205
+ x.stride(0), x.stride(1), x.stride(2),
206
+ dt.stride(0), dt.stride(1), dt.stride(2),
207
+ *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,
208
+ A.stride(0), A.stride(1), A.stride(2),
209
+ B.stride(0), B.stride(1), B.stride(2),
210
+ C.stride(0), C.stride(1), C.stride(2),
211
+ *(D.stride(0), D.stride(1)) if D is not None else 0,
212
+ z_strides[0], z_strides[1], z_strides[2],
213
+ out.stride(0), out.stride(1), out.stride(2),
214
+ dt_softplus,
215
+ tie_hdim,
216
+ BLOCK_SIZE_M,
217
+ num_warps=num_warps,
218
+ )
219
+ if not has_heads:
220
+ out = out.squeeze(1)
221
+ return out
222
+
223
+
224
+ def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
225
+ """
226
+ Argument:
227
+ state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
228
+ x: (batch, dim) or (batch, nheads, dim)
229
+ dt: (batch, dim) or (batch, nheads, dim)
230
+ A: (dim, dstate) or (nheads, dim, dstate)
231
+ B: (batch, dstate) or (batch, ngroups, dstate)
232
+ C: (batch, dstate) or (batch, ngroups, dstate)
233
+ D: (dim,) or (nheads, dim)
234
+ z: (batch, dim) or (batch, nheads, dim)
235
+ dt_bias: (dim,) or (nheads, dim)
236
+ Return:
237
+ out: (batch, dim) or (batch, nheads, dim)
238
+ """
239
+ has_heads = state.dim() > 3
240
+ if state.dim() == 3:
241
+ state = state.unsqueeze(1)
242
+ if x.dim() == 2:
243
+ x = x.unsqueeze(1)
244
+ if dt.dim() == 2:
245
+ dt = dt.unsqueeze(1)
246
+ if A.dim() == 2:
247
+ A = A.unsqueeze(0)
248
+ if B.dim() == 2:
249
+ B = B.unsqueeze(1)
250
+ if C.dim() == 2:
251
+ C = C.unsqueeze(1)
252
+ if D is not None and D.dim() == 1:
253
+ D = D.unsqueeze(0)
254
+ if z is not None and z.dim() == 2:
255
+ z = z.unsqueeze(1)
256
+ if dt_bias is not None and dt_bias.dim() == 1:
257
+ dt_bias = dt_bias.unsqueeze(0)
258
+ batch, nheads, dim, dstate = state.shape
259
+ assert x.shape == (batch, nheads, dim)
260
+ assert dt.shape == x.shape
261
+ assert A.shape == (nheads, dim, dstate)
262
+ ngroups = B.shape[1]
263
+ assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
264
+ assert B.shape == (batch, ngroups, dstate)
265
+ assert C.shape == B.shape
266
+ if D is not None:
267
+ assert D.shape == (nheads, dim)
268
+ if z is not None:
269
+ assert z.shape == x.shape
270
+ if dt_bias is not None:
271
+ assert dt_bias.shape == (nheads, dim)
272
+ dt = dt + dt_bias
273
+ dt = F.softplus(dt) if dt_softplus else dt
274
+ dA = torch.exp(rearrange(dt, "b h d -> b h d 1") * A) # (batch, nheads, dim, dstate)
275
+ B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
276
+ C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
277
+ dB = rearrange(dt, "b h d -> b h d 1") * rearrange(B, "b h n -> b h 1 n") # (batch, nheads, dim, dstate)
278
+ state.copy_(state * dA + dB * rearrange(x, "b h d -> b h d 1")) # (batch, nheads, dim, dstate)
279
+ out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
280
+ if D is not None:
281
+ out += (x * D).to(out.dtype)
282
+ out = (out if z is None else out * F.silu(z)).to(x.dtype)
283
+ if not has_heads:
284
+ out = out.squeeze(1)
285
+ return out
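The fused single-step update can be sanity-checked against the pure-PyTorch reference implementation above. A minimal sketch (editor-added, not part of the diff; the shapes, dtypes, and CUDA device are illustrative assumptions):

# Hypothetical usage sketch; assumes both functions are importable from this module
# and that a CUDA device is available.
import torch

batch, nheads, dim, dstate = 2, 4, 64, 16
state = torch.zeros(batch, nheads, dim, dstate, device="cuda")  # updated in place
x = torch.randn(batch, nheads, dim, device="cuda")
dt = torch.rand(batch, nheads, dim, device="cuda")
A = -torch.rand(nheads, dim, dstate, device="cuda")
B = torch.randn(batch, 1, dstate, device="cuda")  # ngroups = 1
C = torch.randn(batch, 1, dstate, device="cuda")
D = torch.randn(nheads, dim, device="cuda")
z = torch.randn(batch, nheads, dim, device="cuda")
dt_bias = torch.rand(nheads, dim, device="cuda")

state_ref = state.clone()
out = selective_state_update(state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
print(torch.allclose(out, out_ref, atol=1e-3), torch.allclose(state, state_ref, atol=1e-3))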
build/torch210-cxx11-cu126-x86_64-linux/ops/triton/softplus.py ADDED
@@ -0,0 +1,15 @@
1
+ import triton
2
+ import triton.language as tl
3
+ from packaging import version
4
+
5
+ TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")
6
+
7
+
8
+ if TRITON3:
9
+ @triton.jit
10
+ def softplus(dt):
11
+ return tl.math.log(tl.math.exp(dt) + 1)
12
+ else:
13
+ @triton.jit
14
+ def softplus(dt):
15
+ return tl.math.log1p(tl.exp(dt))
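Editor-added note: both branches above compute the same function, softplus(x) = log(1 + exp(x)); the two spellings only track an API difference between Triton releases. A quick plain-PyTorch check of the identity (illustrative only):

import torch
x = torch.linspace(-5.0, 5.0, steps=11)
print(torch.allclose(torch.log1p(torch.exp(x)), torch.log(torch.exp(x) + 1.0)))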
build/torch210-cxx11-cu126-x86_64-linux/ops/triton/ssd_bmm.py ADDED
@@ -0,0 +1,262 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ """We want triton==2.1.0 or 2.2.0 for this
4
+ """
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ import triton
11
+ import triton.language as tl
12
+
13
+ from einops import rearrange, repeat
14
+
15
+
16
+ def init_to_zero(names):
17
+ return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
18
+
19
+
20
+ @triton.autotune(
21
+ configs=[
22
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
23
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
24
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
25
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
26
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
27
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
28
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
29
+ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
30
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
31
+ ],
32
+ key=['chunk_size', 'K', 'IS_CAUSAL'],
33
+ )
34
+ @triton.jit
35
+ def _bmm_chunk_fwd_kernel(
36
+ # Pointers to matrices
37
+ a_ptr, b_ptr, out_ptr, seq_idx_ptr,
38
+ # Matrix dimensions
39
+ seqlen, chunk_size, K, ngroups,
40
+ stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,
41
+ stride_b_batch, stride_b_seqlen, stride_b_head, stride_bk,
42
+ stride_out_batch, stride_out_chunk, stride_out_head, stride_outm, stride_outn,
43
+ stride_seq_idx_batch, stride_seq_idx_seqlen,
44
+ # Meta-parameters
45
+ IS_CAUSAL: tl.constexpr,
46
+ dot_dtype: tl.constexpr,
47
+ HAS_SEQ_IDX: tl.constexpr,
48
+ BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
49
+ ):
50
+ pid_b = tl.program_id(axis=1)
51
+ pid_ch = tl.program_id(axis=2)
52
+ pid_c = pid_ch // ngroups
53
+ pid_h = pid_ch - pid_c * ngroups
54
+ num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
55
+ pid_m = tl.program_id(axis=0) // num_pid_n
56
+ pid_n = tl.program_id(axis=0) % num_pid_n
57
+ if IS_CAUSAL:
58
+ if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
59
+ return
60
+ a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
61
+ b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head
62
+ if HAS_SEQ_IDX:
63
+ seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
64
+
65
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
66
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
67
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
68
+ a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)
69
+ b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)
70
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
71
+
72
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
73
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
74
+ a = tl.load(a_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0).to(dot_dtype)
75
+ b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_n[None, :] < chunk_size_limit), other=0.0).to(dot_dtype)
76
+ acc += tl.dot(a, b)
77
+ a_ptrs += BLOCK_SIZE_K * stride_ak
78
+ b_ptrs += BLOCK_SIZE_K * stride_bk
79
+
80
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
81
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
82
+ if HAS_SEQ_IDX:
83
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
84
+ seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
85
+ seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2)
86
+ acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
87
+ out = acc.to(out_ptr.dtype.element_ty)
88
+
89
+ out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head
90
+ out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)
91
+ tl.store(out_ptrs, out, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size))
92
+
93
+
94
+ @triton.autotune(
95
+ configs=[
96
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 64}, num_stages=3, num_warps=8),
97
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
98
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
99
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
100
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
101
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),
102
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),
103
+ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),
104
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=2),
105
+ ],
106
+ key=['chunk_size', 'K'],
107
+ )
108
+ @triton.jit
109
+ def _bmm_chunk_bwd_kernel(
110
+ # Pointers to matrices
111
+ a_ptr, dout_ptr, db_ptr, res_ptr,
112
+ # Matrix dimensions
113
+ seqlen, chunk_size, K, ngroups,
114
+ stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,
115
+ stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_csize_m, stride_dout_csize_n,
116
+ stride_db_batch, stride_db_seqlen, stride_db_head, stride_db_k,
117
+ stride_res_batch, stride_res_seqlen, stride_res_head, stride_res_k,
118
+ # Meta-parameters
119
+ dot_dtype: tl.constexpr,
120
+ HAS_RESIDUAL: tl.constexpr,
121
+ BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_CS: tl.constexpr,
122
+ ):
123
+ pid_b = tl.program_id(axis=1)
124
+ pid_ch = tl.program_id(axis=2)
125
+ pid_c = pid_ch // ngroups
126
+ pid_h = pid_ch - pid_c * ngroups
127
+ num_pid_n = tl.cdiv(K, BLOCK_SIZE_N)
128
+ pid_m = tl.program_id(axis=0) // num_pid_n
129
+ pid_n = tl.program_id(axis=0) % num_pid_n
130
+
131
+ a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
132
+ dout_ptr += pid_b * stride_dout_batch + pid_c * stride_dout_chunk + pid_h * stride_dout_head
133
+
134
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
135
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
136
+ offs_cs = tl.arange(0, BLOCK_SIZE_CS)
137
+ dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize_n + offs_cs[None, :] * stride_dout_csize_m)
138
+ a_ptrs = a_ptr + (offs_cs[:, None] * stride_a_seqlen + offs_n[None, :] * stride_ak)
139
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
140
+
141
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
142
+ for cs in range(0, tl.cdiv(chunk_size_limit, BLOCK_SIZE_CS)):
143
+ dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_cs[None, :] < chunk_size_limit - cs * BLOCK_SIZE_CS), other=0.0).to(dot_dtype)
144
+ a = tl.load(a_ptrs, mask=(offs_cs[:, None] < chunk_size_limit - cs * BLOCK_SIZE_CS) & (offs_n[None, :] < K), other=0.0).to(dot_dtype)
145
+ acc += tl.dot(dout, a)
146
+ dout_ptrs += BLOCK_SIZE_CS * stride_dout_csize_m
147
+ a_ptrs += BLOCK_SIZE_CS * stride_a_seqlen
148
+
149
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
150
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
151
+ if HAS_RESIDUAL:
152
+ res_ptr += pid_b * stride_res_batch + pid_c * chunk_size * stride_res_seqlen + pid_h * stride_res_head
153
+ res_ptrs = res_ptr + (offs_m[:, None] * stride_res_seqlen + offs_n[None, :] * stride_res_k)
154
+ res = tl.load(res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K)).to(tl.float32)
155
+ acc += res
156
+ db = acc.to(db_ptr.dtype.element_ty)
157
+
158
+ db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_h * stride_db_head
159
+ db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_k)
160
+ tl.store(db_ptrs, db, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K))
161
+
162
+
163
+ def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None):
164
+ """
165
+ Argument:
166
+ a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
167
+ b: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
168
+ seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out.
169
+ causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are
170
+ guaranteed to be correct.
171
+ Return:
172
+ out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)
173
+ """
174
+ # Check constraints.
175
+ has_groups = a.dim() == 4
176
+ if not has_groups:
177
+ batch, seqlen, k = a.shape
178
+ else:
179
+ batch, seqlen, ngroups, k = a.shape
180
+ assert b.shape == a.shape
181
+ if seq_idx is not None:
182
+ assert seq_idx.shape == (batch, seqlen)
183
+ if a.stride(-1) != 1 and a.stride(1) != 1:
184
+ a = a.contiguous()
185
+ if b.stride(-1) != 1 and b.stride(1) != 1:
186
+ b = b.contiguous()
187
+ nchunks = math.ceil(seqlen / chunk_size)
188
+ # Allocates output.
189
+ out_dtype = a.dtype if output_dtype is None else output_dtype
190
+ out = torch.empty((batch, nchunks, chunk_size, chunk_size) if not has_groups else (batch, nchunks, ngroups, chunk_size, chunk_size),
191
+ device=a.device, dtype=out_dtype)
192
+ dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else
193
+ (tl.float16 if a.dtype == torch.float16 or b.dtype == torch.float16 else tl.float32))
194
+ grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']),
195
+ batch, nchunks if not has_groups else nchunks * ngroups)
196
+ with torch.cuda.device(a.device.index):
197
+ _bmm_chunk_fwd_kernel[grid](
198
+ a, b, out, seq_idx,
199
+ seqlen, chunk_size, k, ngroups if has_groups else 1,
200
+ a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),
201
+ b.stride(0), b.stride(1), 0 if not has_groups else b.stride(2), b.stride(-1),
202
+ out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-2), out.stride(-1),
203
+ *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
204
+ causal,
205
+ dot_dtype,
206
+ HAS_SEQ_IDX=seq_idx is not None,
207
+ )
208
+ return out
209
+
210
+
211
+ def _bmm_chunk_bwd(a, dout, residual=None, out=None):
212
+ """
213
+ Argument:
214
+ a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
215
+ dout: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)
216
+ residual: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
217
+ Return:
218
+ out: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
219
+
220
+ If there was seq_idx in the fwd pass, then dout[i, j] for seq_idx[i] != seq_idx[j] should already be
221
+ zeroed out before calling this function.
222
+ """
223
+ # Check constraints.
224
+ has_groups = a.dim() == 4
225
+ if not has_groups:
226
+ batch, seqlen, k = a.shape
227
+ else:
228
+ batch, seqlen, ngroups, k = a.shape
229
+ nchunks, chunk_size = dout.shape[1], dout.shape[-1]
230
+ if a.stride(-1) != 1 and a.stride(-2) != 1:
231
+ a = a.contiguous()
232
+ if dout.stride(-1) != 1 and dout.stride(-2) != 1:
233
+ dout = dout.contiguous()
234
+ if residual is not None:
235
+ assert residual.shape == (batch, seqlen, k) if not has_groups else (batch, seqlen, ngroups, k)
236
+ if residual.stride(-1) != 1 and residual.stride(1) != 1:
237
+ residual = residual.contiguous()
238
+ # Allocates output.
239
+ if out is not None:
240
+ assert out.shape == a.shape
241
+ assert out.stride(-1) == 1 or out.stride(1) == 1
242
+ else:
243
+ out = torch.empty_like(a)
244
+ dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or dout.dtype == torch.bfloat16 else
245
+ (tl.float16 if a.dtype == torch.float16 or dout.dtype == torch.float16 else tl.float32))
246
+ grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(k, META['BLOCK_SIZE_N']), batch,
247
+ nchunks if not has_groups else nchunks * ngroups)
248
+ residual_strides = ((residual.stride(0), residual.stride(1), 0 if not has_groups else residual.stride(2),
249
+ residual.stride(-1))
250
+ if residual is not None else (0, 0, 0, 0))
251
+ with torch.cuda.device(a.device.index):
252
+ _bmm_chunk_bwd_kernel[grid](
253
+ a, dout, out, residual,
254
+ seqlen, chunk_size, k, ngroups if has_groups else 1,
255
+ a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),
256
+ dout.stride(0), dout.stride(1), 0 if not has_groups else dout.stride(2), dout.stride(-2), dout.stride(-1),
257
+ out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-1),
258
+ residual_strides[0], residual_strides[1], residual_strides[2], residual_strides[3],
259
+ dot_dtype,
260
+ HAS_RESIDUAL=residual is not None,
261
+ )
262
+ return out
build/torch210-cxx11-cu126-x86_64-linux/ops/triton/ssd_chunk_scan.py ADDED
The diff for this file is too large to render. See raw diff
 
build/torch210-cxx11-cu126-x86_64-linux/ops/triton/ssd_chunk_state.py ADDED
@@ -0,0 +1,997 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ """We want triton==2.1.0 or 2.2.0 for this
4
+ """
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ import triton
11
+ import triton.language as tl
12
+
13
+ from einops import rearrange, repeat
14
+
15
+ from .softplus import softplus
16
+
17
+
18
+ def init_to_zero(names):
19
+ return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
20
+
21
+ @triton.autotune(
22
+ configs=[
23
+ triton.Config({'BLOCK_SIZE_H': 1}),
24
+ triton.Config({'BLOCK_SIZE_H': 2}),
25
+ triton.Config({'BLOCK_SIZE_H': 4}),
26
+ triton.Config({'BLOCK_SIZE_H': 8}),
27
+ triton.Config({'BLOCK_SIZE_H': 16}),
28
+ triton.Config({'BLOCK_SIZE_H': 32}),
29
+ triton.Config({'BLOCK_SIZE_H': 64}),
30
+ ],
31
+ key=['chunk_size', 'nheads'],
32
+ )
33
+ @triton.jit
34
+ def _chunk_cumsum_fwd_kernel(
35
+ # Pointers to matrices
36
+ dt_ptr, A_ptr, dt_bias_ptr, dt_out_ptr, dA_cumsum_ptr,
37
+ # Matrix dimension
38
+ batch, seqlen, nheads, chunk_size,
39
+ dt_min, dt_max,
40
+ # Strides
41
+ stride_dt_batch, stride_dt_seqlen, stride_dt_head,
42
+ stride_A_head,
43
+ stride_dt_bias_head,
44
+ stride_dt_out_batch, stride_dt_out_chunk, stride_dt_out_head, stride_dt_out_csize,
45
+ stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
46
+ # Meta-parameters
47
+ DT_SOFTPLUS: tl.constexpr,
48
+ HAS_DT_BIAS: tl.constexpr,
49
+ BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr,
50
+ ):
51
+ pid_b = tl.program_id(axis=0)
52
+ pid_c = tl.program_id(axis=1)
53
+ pid_h = tl.program_id(axis=2)
54
+ dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
55
+ dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk
56
+ dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk
57
+
58
+ offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
59
+ offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
60
+ dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen)
61
+ A_ptrs = A_ptr + offs_h * stride_A_head
62
+ dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize)
63
+ dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize)
64
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
65
+
66
+ dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)
67
+ if HAS_DT_BIAS:
68
+ dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32)
69
+ dt += dt_bias[:, None]
70
+ if DT_SOFTPLUS:
71
+ dt = tl.where(dt <= 20.0, softplus(dt), dt)
72
+ # As of Triton 2.2.0, tl.clamp is not available yet
73
+ # dt = tl.clamp(dt, dt_min, dt_max)
74
+ dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
75
+ dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0)
76
+ tl.store(dt_out_ptrs, dt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))
77
+ A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
78
+ dA = dt * A[:, None]
79
+ dA_cs = tl.cumsum(dA, axis=1)
80
+ tl.store(dA_cs_ptrs, dA_cs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))
81
+
82
+
83
+ @triton.autotune(
84
+ configs=[
85
+ triton.Config({'BLOCK_SIZE_H': 1}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
86
+ triton.Config({'BLOCK_SIZE_H': 2}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
87
+ triton.Config({'BLOCK_SIZE_H': 4}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
88
+ triton.Config({'BLOCK_SIZE_H': 8}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
89
+ triton.Config({'BLOCK_SIZE_H': 16}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
90
+ triton.Config({'BLOCK_SIZE_H': 32}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
91
+ triton.Config({'BLOCK_SIZE_H': 64}, pre_hook=init_to_zero(["dA_ptr", "ddt_bias_ptr"])),
92
+ ],
93
+ key=['chunk_size', 'nheads'],
94
+ )
95
+ @triton.jit
96
+ def _chunk_cumsum_bwd_kernel(
97
+ # Pointers to matrices
98
+ ddA_ptr, ddt_out_ptr, dt_ptr, A_ptr, dt_bias_ptr,
99
+ ddt_ptr, dA_ptr, ddt_bias_ptr,
100
+ # Matrix dimensions
101
+ batch, seqlen, nheads, chunk_size,
102
+ dt_min, dt_max,
103
+ # Strides
104
+ stride_ddA_batch, stride_ddA_chunk, stride_ddA_head, stride_ddA_csize,
105
+ stride_ddt_out_batch, stride_ddt_out_chunk, stride_ddt_out_head, stride_ddt_out_csize,
106
+ stride_dt_batch, stride_dt_seqlen, stride_dt_head,
107
+ stride_A_head,
108
+ stride_dt_bias_head,
109
+ stride_ddt_batch, stride_ddt_seqlen, stride_ddt_head,
110
+ stride_dA_head,
111
+ stride_ddt_bias_head,
112
+ # Meta-parameters
113
+ DT_SOFTPLUS: tl.constexpr,
114
+ HAS_DT_BIAS: tl.constexpr,
115
+ BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr,
116
+ ):
117
+ pid_b = tl.program_id(axis=0)
118
+ pid_c = tl.program_id(axis=1)
119
+ pid_h = tl.program_id(axis=2)
120
+ ddt_out_ptr += pid_b * stride_ddt_out_batch + pid_c * stride_ddt_out_chunk
121
+ ddA_ptr += pid_b * stride_ddA_batch + pid_c * stride_ddA_chunk
122
+ dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
123
+ ddt_ptr += pid_b * stride_ddt_batch + pid_c * chunk_size * stride_ddt_seqlen
124
+
125
+ offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
126
+ offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
127
+ ddt_out_ptrs = ddt_out_ptr + (offs_h[:, None] * stride_ddt_out_head + offs_c[None, :] * stride_ddt_out_csize)
128
+ ddA_ptrs = ddA_ptr + (offs_h[:, None] * stride_ddA_head + offs_c[None, :] * stride_ddA_csize)
129
+ dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen)
130
+ ddt_ptrs = ddt_ptr + (offs_h[:, None] * stride_ddt_head + offs_c[None, :] * stride_ddt_seqlen)
131
+ A_ptrs = A_ptr + offs_h * stride_A_head
132
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
133
+
134
+ ddA = tl.load(ddA_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)
135
+ ddt_out = tl.load(ddt_out_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)
136
+ A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
137
+ ddt = ddA * A[:, None] + ddt_out
138
+ dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)
139
+ if HAS_DT_BIAS:
140
+ dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32)
141
+ dt += dt_bias[:, None]
142
+ if DT_SOFTPLUS:
143
+ dt_presoftplus = dt
144
+ dt = tl.where(dt <= 20.0, softplus(dt), dt)
145
+ clamp_mask = (dt < dt_min) | (dt > dt_max)
146
+ # As of Triton 2.2.0, tl.clamp is not available yet
147
+ # dt = tl.clamp(dt, dt_min, dt_max)
148
+ dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
149
+ dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0)
150
+ ddt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), ddt, 0.0)
151
+ ddt = tl.where(clamp_mask, 0.0, ddt)
152
+ if DT_SOFTPLUS:
153
+ ddt = tl.where(dt_presoftplus <= 20.0, ddt * tl.sigmoid(dt_presoftplus), ddt)
154
+ tl.store(ddt_ptrs, ddt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit))
155
+ dA = tl.sum(ddA * dt, axis=1)
156
+ tl.atomic_add(dA_ptr + offs_h * stride_dA_head, dA, mask=offs_h < nheads)
157
+ if HAS_DT_BIAS:
158
+ ddt_bias = tl.sum(ddt, axis=1)
159
+ tl.atomic_add(ddt_bias_ptr + offs_h * stride_ddt_bias_head, ddt_bias, mask=offs_h < nheads)
160
+
161
+
162
+ @triton.autotune(
163
+ configs=[
164
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
165
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
166
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
167
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
168
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
169
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
170
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
171
+ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
172
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
173
+ ],
174
+ key=['hdim', 'dstate', 'chunk_size'],
175
+ )
176
+ @triton.jit
177
+ def _chunk_state_fwd_kernel(
178
+ # Pointers to matrices
179
+ x_ptr, b_ptr, states_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr,
180
+ # Matrix dimensions
181
+ hdim, dstate, chunk_size,
182
+ batch, seqlen, nheads_ngroups_ratio,
183
+ # Strides
184
+ stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
185
+ stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
186
+ stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
187
+ stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
188
+ stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
189
+ stride_seq_idx_batch, stride_seq_idx_seqlen,
190
+ # Meta-parameters
191
+ HAS_SEQ_IDX: tl.constexpr,
192
+ BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
193
+ ):
194
+ pid_bc = tl.program_id(axis=1)
195
+ pid_c = pid_bc // batch
196
+ pid_b = pid_bc - pid_c * batch
197
+ pid_h = tl.program_id(axis=2)
198
+ num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
199
+ pid_m = tl.program_id(axis=0) // num_pid_n
200
+ pid_n = tl.program_id(axis=0) % num_pid_n
201
+ b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
202
+ x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
203
+ dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
204
+ dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
205
+ if HAS_SEQ_IDX:
206
+ seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
207
+
208
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
209
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
210
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
211
+ x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen)
212
+ b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen)
213
+ dt_ptrs = dt_ptr + offs_k * stride_dt_csize
214
+ dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
215
+ dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
216
+ if HAS_SEQ_IDX:
217
+ seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen
218
+
219
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
220
+ if HAS_SEQ_IDX:
221
+ seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
222
+
223
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
224
+ for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
225
+ x = tl.load(x_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), other=0.0)
226
+ b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
227
+ dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
228
+ if HAS_SEQ_IDX:
229
+ seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1)
230
+ dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
231
+ if not HAS_SEQ_IDX:
232
+ # scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k
233
+ scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k
234
+ else:
235
+ # scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
236
+ scale = tl.where((seq_idx_last >= 0) & (seq_idx_k == seq_idx_last), tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k, 0.0)
237
+ b *= scale[:, None]
238
+ b = b.to(x_ptr.dtype.element_ty)
239
+ acc += tl.dot(x, b)
240
+ x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
241
+ b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
242
+ dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
243
+ dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
244
+ if HAS_SEQ_IDX:
245
+ seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen
246
+ states = acc.to(states_ptr.dtype.element_ty)
247
+
248
+ states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head
249
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
250
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
251
+ states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate)
252
+ c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
253
+ tl.store(states_ptrs, states, mask=c_mask)
254
+
255
+
256
+ @triton.autotune(
257
+ configs=[
258
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
259
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
260
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
261
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
262
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
263
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
264
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
265
+ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
266
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr", "ddA_cumsum_ptr"])),
267
+ ],
268
+ key=['chunk_size', 'hdim', 'dstate'],
269
+ )
270
+ @triton.jit
271
+ def _chunk_state_bwd_dx_kernel(
272
+ # Pointers to matrices
273
+ x_ptr, b_ptr, dstates_ptr, dt_ptr, dA_cumsum_ptr,
274
+ dx_ptr, ddt_ptr, ddA_cumsum_ptr,
275
+ # Matrix dimensions
276
+ chunk_size, hdim, dstate,
277
+ batch, seqlen, nheads_ngroups_ratio,
278
+ # Strides
279
+ stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
280
+ stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
281
+ stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
282
+ stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
283
+ stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
284
+ stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim,
285
+ stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize,
286
+ stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
287
+ # Meta-parameters
288
+ BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
289
+ BLOCK_SIZE_DSTATE: tl.constexpr,
290
+ ):
291
+ pid_bc = tl.program_id(axis=1)
292
+ pid_c = pid_bc // batch
293
+ pid_b = pid_bc - pid_c * batch
294
+ pid_h = tl.program_id(axis=2)
295
+ num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
296
+ pid_m = tl.program_id(axis=0) // num_pid_n
297
+ pid_n = tl.program_id(axis=0) % num_pid_n
298
+ x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
299
+ b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
300
+ dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_states_head
301
+ dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
302
+ ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
303
+ ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head
304
+ dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
305
+
306
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
307
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
308
+
309
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
310
+ # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
311
+ offs_k = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
312
+ b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate)
313
+ dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate)
314
+ if BLOCK_SIZE_DSTATE <= 128:
315
+ b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate), other=0.0)
316
+ dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)
317
+ dstates = dstates.to(b_ptr.dtype.element_ty)
318
+ acc = tl.dot(b, dstates)
319
+ else:
320
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
321
+ for k in range(0, dstate, BLOCK_SIZE_K):
322
+ b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate - k), other=0.0)
323
+ dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)
324
+ dstates = dstates.to(b_ptr.dtype.element_ty)
325
+ acc += tl.dot(b, dstates)
326
+ b_ptrs += BLOCK_SIZE_K * stride_b_dstate
327
+ dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate
328
+
329
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
330
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
331
+
332
+ dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
333
+ dt_ptrs = dt_ptr + offs_m * stride_dt_csize
334
+ dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
335
+ dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
336
+ dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
337
+ # acc *= tl.exp(dA_cs_last - dA_cs_m)[:, None]
338
+ acc *= tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))[:, None]
339
+
340
+ x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
341
+ x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
342
+ ddt = tl.sum(acc * x, axis=1)
343
+ ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
344
+ tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
345
+ ddA_cs = -(ddt * dt_m)
346
+ ddA_cs_last = -tl.sum(ddA_cs)
347
+ ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
348
+ tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
349
+ tl.atomic_add(ddA_cumsum_ptr + (chunk_size - 1) * stride_ddA_cs_csize, ddA_cs_last)
350
+
351
+ dx = (acc * dt_m[:, None]).to(dx_ptr.dtype.element_ty)
352
+ dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head
353
+ dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim)
354
+ tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))
355
+
356
+
357
+ @triton.autotune(
358
+ configs=[
359
+ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
360
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
361
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
362
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
363
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
364
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
365
+ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
366
+ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
367
+ ],
368
+ key=['chunk_size', 'dstate', 'hdim'],
369
+ )
370
+ @triton.jit
371
+ def _chunk_state_bwd_db_kernel(
372
+ # Pointers to matrices
373
+ x_ptr, dstates_ptr, b_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr,
374
+ db_ptr, ddA_cumsum_ptr,
375
+ # Matrix dimensions
376
+ chunk_size, dstate, hdim,
377
+ batch, seqlen, nheads, nheads_per_program, ngroups,
378
+ # Strides
379
+ stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
380
+ stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
381
+ stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
382
+ stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
383
+ stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
384
+ stride_seq_idx_batch, stride_seq_idx_seqlen,
385
+ stride_db_batch, stride_db_seqlen, stride_db_split, stride_db_group, stride_db_dstate,
386
+ stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
387
+ # Meta-parameters
388
+ HAS_DDA_CS: tl.constexpr,
389
+ HAS_SEQ_IDX: tl.constexpr,
390
+ BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
391
+ ):
392
+ pid_bc = tl.program_id(axis=1)
393
+ pid_c = pid_bc // batch
394
+ pid_b = pid_bc - pid_c * batch
395
+ pid_sg = tl.program_id(axis=2)
396
+ pid_s = pid_sg // ngroups
397
+ pid_g = pid_sg - pid_s * ngroups
398
+ num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
399
+ pid_m = tl.program_id(axis=0) // num_pid_n
400
+ pid_n = tl.program_id(axis=0) % num_pid_n
401
+ x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head
402
+ db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_g * stride_db_group + pid_s * stride_db_split
403
+ dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_states_head
404
+ dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head
405
+ dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head
406
+ if HAS_DDA_CS:
407
+ b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_g * stride_b_head
408
+ ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head
409
+ if HAS_SEQ_IDX:
410
+ seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
411
+
412
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
413
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
414
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
415
+ x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_k[None, :] * stride_x_hdim)
416
+ dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_dstate + offs_k[:, None] * stride_states_hdim)
417
+ dt_ptrs = dt_ptr + offs_m * stride_dt_csize
418
+ dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
419
+ if HAS_DDA_CS:
420
+ b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_n[None, :] * stride_b_dstate)
421
+ ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
422
+
423
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
424
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
425
+ if HAS_DDA_CS:
426
+ b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
427
+ if HAS_SEQ_IDX:
428
+ seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
429
+ seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
430
+ nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program)
431
+ for h in range(nheads_iter):
432
+ x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)
433
+ dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0)
434
+ dstates = dstates.to(x_ptrs.dtype.element_ty)
435
+ db = tl.dot(x, dstates)
436
+ dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
437
+ dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
438
+ dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
439
+ if not HAS_SEQ_IDX:
440
+ # scale = tl.exp(dA_cs_last - dA_cs_m)
441
+ scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))
442
+ else:
443
+ # scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
444
+ scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0)
445
+ db *= (scale * dt_m)[:, None]
446
+ if HAS_DDA_CS:
447
+ # This is the gradient wrt (dA_cs_last - dA_cs_m), i.e. the exclusive reverse cumsum
448
+ ddA_cs = tl.sum(db * b, axis=1)
449
+ tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1)
450
+ acc += db
451
+ x_ptrs += stride_x_head
452
+ dstates_ptrs += stride_states_head
453
+ dt_ptrs += stride_dt_head
454
+ dA_cumsum_ptr += stride_dA_cs_head
455
+ dA_cumsum_ptrs += stride_dA_cs_head
456
+ if HAS_DDA_CS:
457
+ ddA_cumsum_ptrs += stride_ddA_cs_head
458
+
459
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
460
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
461
+ # if HAS_SEQ_IDX:
462
+ # seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
463
+ # seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
464
+ # acc = tl.where(seq_idx_m[:, None] == seq_idx_last, acc, 0.0)
465
+ db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_dstate)
466
+ tl.store(db_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate))
467
+
468
+
469
+ @triton.autotune(
470
+ configs=[
471
+ # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
472
+ # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
473
+ # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
474
+ # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
475
+ # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
476
+ # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
477
+ # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
478
+ # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
479
+ # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
480
+ triton.Config({'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
481
+ triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
482
+ triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
483
+ triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
484
+ triton.Config({'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
485
+ triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
486
+ triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
487
+ triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
488
+ ],
489
+ key=['chunk_size', 'hdim', 'dstate'],
490
+ )
491
+ @triton.jit
492
+ def _chunk_state_bwd_ddAcs_stable_kernel(
493
+ # Pointers to matrices
494
+ x_ptr, b_ptr, dstates_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr,
495
+ ddA_cumsum_ptr,
496
+ # Matrix dimensions
497
+ chunk_size, hdim, dstate,
498
+ batch, seqlen, nheads_ngroups_ratio,
499
+ # Strides
500
+ stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
501
+ stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
502
+ stride_dstates_batch, stride_dstates_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
503
+ stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
504
+ stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
505
+ stride_seq_idx_batch, stride_seq_idx_seqlen,
506
+ stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
507
+ # Meta-parameters
508
+ HAS_SEQ_IDX: tl.constexpr,
509
+ BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
510
+ BLOCK_SIZE_DSTATE: tl.constexpr,
511
+ ):
512
+ pid_bc = tl.program_id(axis=1)
513
+ pid_c = pid_bc // batch
514
+ pid_b = pid_bc - pid_c * batch
515
+ pid_h = tl.program_id(axis=2)
516
+ num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
517
+ pid_m = tl.program_id(axis=0) // num_pid_n
518
+ pid_n = tl.program_id(axis=0) % num_pid_n
519
+ x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
520
+ b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
521
+ dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_states_head
522
+ dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
523
+ ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head
524
+ dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
525
+ if HAS_SEQ_IDX:
526
+ seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
527
+
528
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
529
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
530
+
531
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
532
+ # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
533
+ offs_k = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
534
+ b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_k[None, :] * stride_b_dstate)
535
+ dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_states_hdim + offs_k[:, None] * stride_states_dstate)
536
+ if BLOCK_SIZE_DSTATE <= 128:
537
+ b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate), other=0.0)
538
+ dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)
539
+ dstates = dstates.to(b_ptr.dtype.element_ty)
540
+ acc = tl.dot(b, dstates)
541
+ else:
542
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
543
+ for k in range(0, dstate, BLOCK_SIZE_K):
544
+ b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < dstate - k), other=0.0)
545
+ dstates = tl.load(dstates_ptrs, mask=(offs_k[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)
546
+ dstates = dstates.to(b_ptr.dtype.element_ty)
547
+ acc += tl.dot(b, dstates)
548
+ b_ptrs += BLOCK_SIZE_K * stride_b_dstate
549
+ dstates_ptrs += BLOCK_SIZE_K * stride_states_dstate
550
+
551
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
552
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
553
+
554
+ dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
555
+ dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
556
+ if not HAS_SEQ_IDX:
557
+ # scale = tl.exp(dA_cs_last - dA_cs_m)
558
+ scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))
559
+ else:
560
+ seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
561
+ seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
562
+ # scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
563
+ scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0)
564
+ acc *= scale[:, None]
565
+
566
+ x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
567
+ x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
568
+ dt_ptrs = dt_ptr + offs_m * stride_dt_csize
569
+ dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
570
+ ddt = tl.sum(acc * x, axis=1)
571
+ # ddA_cs = -(ddt * dt_m)
572
+ # Triton 2.2.0 errors if we have the cumsum here, so we just write it out
573
+ # then call torch.cumsum outside this kernel.
574
+ # ddA_cs = tl.cumsum(ddt * dt_m)
575
+ ddA_cs = ddt * dt_m
576
+ ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
577
+ # tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
578
+ tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1)
579
+
580
+
581
+ @triton.autotune(
582
+ configs=[
583
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
584
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
585
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
586
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
587
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
588
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
589
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
590
+ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
591
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
592
+ ],
593
+ key=['hdim', 'dstate', 'chunk_size'],
594
+ )
595
+ @triton.jit
596
+ def _chunk_state_varlen_kernel(
597
+ # Pointers to matrices
598
+ x_ptr, b_ptr, dt_ptr, dA_cumsum_ptr, chunk_states_ptr, cu_seqlens_ptr, states_ptr,
599
+ # Matrix dimensions
600
+ hdim, dstate, chunk_size,
601
+ seqlen, nheads_ngroups_ratio,
602
+ # Strides
603
+ stride_x_seqlen, stride_x_head, stride_x_hdim,
604
+ stride_b_seqlen, stride_b_head, stride_b_dstate,
605
+ stride_dt_chunk, stride_dt_head, stride_dt_csize,
606
+ stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
607
+ stride_chunk_states_chunk, stride_chunk_states_head, stride_chunk_states_hdim, stride_chunk_states_dstate,
608
+ stride_states_batch, stride_states_head, stride_states_hdim, stride_states_dstate,
609
+ # Meta-parameters
610
+ BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
611
+ ):
612
+ pid_b = tl.program_id(axis=1)
613
+ pid_h = tl.program_id(axis=2)
614
+ num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
615
+ pid_m = tl.program_id(axis=0) // num_pid_n
616
+ pid_n = tl.program_id(axis=0) % num_pid_n
617
+ end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
618
+ pid_c = (end_idx - 1) // chunk_size
619
+ b_ptr += pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
620
+ x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
621
+ dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
622
+ dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
623
+ chunk_states_ptr += pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head
624
+
625
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
626
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
627
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
628
+ x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen)
629
+ b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen)
630
+ dt_ptrs = dt_ptr + offs_k * stride_dt_csize
631
+ dA_cs_last = tl.load(dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
632
+ dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
633
+
634
+ chunk_size_limit = end_idx - pid_c * chunk_size
635
+ start_idx = tl.load(cu_seqlens_ptr + pid_b)
636
+ start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0)
637
+
638
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
639
+ for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
640
+ x = tl.load(x_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k) & (offs_k[None, :] >= start_idx_cur - k), other=0.0)
641
+ b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate) & (offs_k[:, None] >= start_idx_cur - k), other=0.0).to(tl.float32)
642
+ dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
643
+ dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
644
+ # scale = tl.where((offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
645
+ # tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
646
+ scale = tl.where((offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
647
+ tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k, 0.0)
648
+ b *= scale[:, None]
649
+ b = b.to(x_ptr.dtype.element_ty)
650
+ acc += tl.dot(x, b)
651
+ x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
652
+ b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
653
+ dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
654
+ dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
655
+
656
+ # If the sequence starts inside this final chunk, there is no state carried over from earlier chunks; otherwise add the carried-over chunk state below.
657
+ if start_idx < pid_c * chunk_size:
658
+ chunk_states_ptrs = chunk_states_ptr + (offs_m[:, None] * stride_chunk_states_hdim + offs_n[None, :] * stride_chunk_states_dstate)
659
+ chunk_states = tl.load(chunk_states_ptrs, mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
660
+ # scale = tl.where(start_idx < pid_c * chunk_size, tl.exp(dA_cs_last), 0.0)
661
+ scale = tl.exp(dA_cs_last)
662
+ acc += chunk_states * scale
663
+
664
+ states = acc.to(states_ptr.dtype.element_ty)
665
+
666
+ states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
667
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
668
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
669
+ states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate)
670
+ c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
671
+ tl.store(states_ptrs, states, mask=c_mask)
672
+
673
+
674
+ def _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))):
675
+ batch, seqlen, nheads = dt.shape
676
+ assert A.shape == (nheads,)
677
+ if dt_bias is not None:
678
+ assert dt_bias.shape == (nheads,)
679
+ nchunks = math.ceil(seqlen / chunk_size)
680
+ dt_out = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)
681
+ dA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)
682
+ grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H']))
683
+ with torch.cuda.device(dt.device.index):
684
+ _chunk_cumsum_fwd_kernel[grid_chunk_cs](
685
+ dt, A, dt_bias, dt_out, dA_cumsum,
686
+ batch, seqlen, nheads, chunk_size,
687
+ dt_limit[0], dt_limit[1],
688
+ dt.stride(0), dt.stride(1), dt.stride(2),
689
+ A.stride(0),
690
+ dt_bias.stride(0) if dt_bias is not None else 0,
691
+ dt_out.stride(0), dt_out.stride(2), dt_out.stride(1), dt_out.stride(3),
692
+ dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
693
+ dt_softplus,
694
+ HAS_DT_BIAS=dt_bias is not None,
695
+ BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
696
+ )
697
+ return dA_cumsum, dt_out
698
+
699
+
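For orientation, here is a rough PyTorch restatement of what `_chunk_cumsum_fwd` returns, assuming the usual Mamba-2 conventions (softplus, bias and limit applied to dt, then a within-chunk cumulative sum of A * dt). This is a hedged sketch that also assumes seqlen is a multiple of chunk_size; it is not a drop-in replacement for the kernel.

import torch
import torch.nn.functional as F
from einops import rearrange

def chunk_cumsum_sketch(dt, A, chunk_size, dt_bias=None, dt_softplus=False,
                        dt_limit=(0.0, float("inf"))):
    # dt: (batch, seqlen, nheads), A: (nheads,) -> both outputs are
    # (batch, nheads, nchunks, chunk_size), as allocated in the wrapper above.
    dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size).float()
    if dt_bias is not None:
        dt = dt + dt_bias.float()[None, :, None, None]
    if dt_softplus:
        dt = F.softplus(dt)
    if dt_limit != (0.0, float("inf")):
        dt = dt.clamp(min=dt_limit[0], max=dt_limit[1])
    dA_cumsum = torch.cumsum(dt * A.float()[None, :, None, None], dim=-1)
    return dA_cumsum, dt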
700
+ def _chunk_cumsum_bwd(ddA, ddt_out, dt, A, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf")), ddt=None):
701
+ batch, seqlen, nheads = dt.shape
702
+ _, _, nchunks, chunk_size = ddA.shape
703
+ assert ddA.shape == (batch, nheads, nchunks, chunk_size)
704
+ assert ddt_out.shape == (batch, nheads, nchunks, chunk_size)
705
+ assert A.shape == (nheads,)
706
+ if dt_bias is not None:
707
+ assert dt_bias.shape == (nheads,)
708
+ ddt_bias = torch.empty_like(dt_bias, dtype=torch.float32)
709
+ else:
710
+ ddt_bias = None
711
+ if ddt is not None:
712
+ assert ddt.shape == dt.shape
713
+ else:
714
+ ddt = torch.empty_like(dt)
715
+ dA = torch.empty_like(A, dtype=torch.float32)
716
+ grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H']))
717
+ with torch.cuda.device(dt.device.index):
718
+ _chunk_cumsum_bwd_kernel[grid_chunk_cs](
719
+ ddA, ddt_out, dt, A, dt_bias, ddt, dA, ddt_bias,
720
+ batch, seqlen, nheads, chunk_size,
721
+ dt_limit[0], dt_limit[1],
722
+ ddA.stride(0), ddA.stride(2), ddA.stride(1), ddA.stride(3),
723
+ ddt_out.stride(0), ddt_out.stride(2), ddt_out.stride(1), ddt_out.stride(3),
724
+ dt.stride(0), dt.stride(1), dt.stride(2),
725
+ A.stride(0),
726
+ dt_bias.stride(0) if dt_bias is not None else 0,
727
+ ddt.stride(0), ddt.stride(1), ddt.stride(2),
728
+ dA.stride(0),
729
+ ddt_bias.stride(0) if ddt_bias is not None else 0,
730
+ dt_softplus,
731
+ HAS_DT_BIAS=dt_bias is not None,
732
+ BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
733
+ )
734
+ return ddt, dA, ddt_bias
735
+
736
+
737
+ def _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_fp32=True):
738
+ batch, seqlen, nheads, headdim = x.shape
739
+ _, _, nchunks, chunk_size = dt.shape
740
+ _, _, ngroups, dstate = B.shape
741
+ assert nheads % ngroups == 0
742
+ assert B.shape == (batch, seqlen, ngroups, dstate)
743
+ assert dt.shape == (batch, nheads, nchunks, chunk_size)
744
+ assert dA_cumsum.shape == dt.shape
745
+ if seq_idx is not None:
746
+ assert seq_idx.shape == (batch, seqlen)
747
+ if states is not None:
748
+ assert states.shape == (batch, nchunks, nheads, headdim, dstate)
749
+ else:
750
+ states_dtype = torch.float32 if states_in_fp32 else B.dtype
751
+ states = torch.empty((batch, nchunks, nheads, headdim, dstate), device=x.device, dtype=states_dtype)
752
+ grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),
753
+ batch * nchunks, nheads)
754
+ with torch.cuda.device(x.device.index):
755
+ _chunk_state_fwd_kernel[grid](
756
+ x, B, states, dt, dA_cumsum, seq_idx,
757
+ headdim, dstate, chunk_size,
758
+ batch, seqlen, nheads // ngroups,
759
+ x.stride(0), x.stride(1), x.stride(2), x.stride(3),
760
+ B.stride(0), B.stride(1), B.stride(2), B.stride(-1),
761
+ states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4),
762
+ dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
763
+ dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
764
+ *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
765
+ HAS_SEQ_IDX=seq_idx is not None,
766
+ )
767
+ return states
768
+
769
+
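The states produced by `_chunk_state_fwd` have a compact einsum description, mirrored by `chunk_state_ref` at the end of this file. The condensed restatement below assumes seqlen is a multiple of chunk_size and is only meant to make the Triton kernel's contraction explicit:

import torch
from einops import rearrange, repeat

def chunk_state_einsum_sketch(B, x, dt, dA_cumsum, chunk_size):
    # states[b, c, h, p, n] = sum_l B[b, c, l, h, n]
    #   * exp(dA_cumsum[b, h, c, -1] - dA_cumsum[b, h, c, l]) * dt[b, h, c, l] * x[b, c, l, h, p]
    nheads, ngroups = x.shape[2], B.shape[2]
    B = repeat(B, "b l g n -> b l (g h) n", h=nheads // ngroups)
    x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size)
    B = rearrange(B, "b (c l) h n -> b c l h n", l=chunk_size)
    decay_states = torch.exp(dA_cumsum[:, :, :, -1:] - dA_cumsum)
    return torch.einsum("bclhn,bhcl,bhcl,bclhp->bchpn",
                        B.to(x.dtype), decay_states.to(x.dtype), dt.to(x.dtype), x)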
770
+ def _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates, dx=None):
771
+ batch, seqlen, nheads, headdim = x.shape
772
+ _, _, nchunks, chunk_size = dt.shape
773
+ _, _, ngroups, dstate = B.shape
774
+ assert nheads % ngroups == 0
775
+ assert B.shape == (batch, seqlen, ngroups, dstate)
776
+ assert dt.shape == (batch, nheads, nchunks, chunk_size)
777
+ assert dA_cumsum.shape == dt.shape
778
+ assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
779
+ if dx is not None:
780
+ assert dx.shape == x.shape
781
+ else:
782
+ dx = torch.empty_like(x)
783
+ ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)
784
+ ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dA_cumsum.device, dtype=torch.float32)
785
+ grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),
786
+ batch * nchunks, nheads)
787
+ with torch.cuda.device(x.device.index):
788
+ _chunk_state_bwd_dx_kernel[grid_dx](
789
+ x, B, dstates, dt, dA_cumsum, dx, ddt, ddA_cumsum,
790
+ chunk_size, headdim, dstate,
791
+ batch, seqlen, nheads // ngroups,
792
+ x.stride(0), x.stride(1), x.stride(2), x.stride(3),
793
+ B.stride(0), B.stride(1), B.stride(2), B.stride(-1),
794
+ dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),
795
+ dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
796
+ dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
797
+ dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3),
798
+ ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3),
799
+ ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3),
800
+ BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
801
+ )
802
+ return dx, ddt.to(dt.dtype), ddA_cumsum.to(dA_cumsum.dtype)
803
+
804
+
805
+ def _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=None, B=None, ngroups=1):
806
+ batch, seqlen, nheads, headdim = x.shape
807
+ _, _, nchunks, chunk_size = dt.shape
808
+ dstate = dstates.shape[-1]
809
+ assert dt.shape == (batch, nheads, nchunks, chunk_size)
810
+ assert dA_cumsum.shape == dt.shape
811
+ assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
812
+ if seq_idx is not None:
813
+ assert seq_idx.shape == (batch, seqlen)
814
+ if B is not None:
815
+ assert B.shape == (batch, seqlen, ngroups, dstate)
816
+ B_strides = (B.stride(0), B.stride(1), B.stride(2), B.stride(3))
817
+ # Use torch.empty since the Triton kernel will call init_to_zero
818
+ ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32)
819
+ ddA_cumsum_strides = (ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3))
820
+ else:
821
+ B_strides = (0, 0, 0, 0)
822
+ ddA_cumsum = None
823
+ ddA_cumsum_strides = (0, 0, 0, 0)
824
+ nheads_ngroups_ratio = nheads // ngroups
825
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
826
+ nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1)
827
+ nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program)
828
+ dB = torch.empty(batch, seqlen, nsplits, ngroups, dstate, device=x.device, dtype=torch.float32)
829
+ grid_db = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),
830
+ batch * nchunks, nsplits * ngroups)
831
+ with torch.cuda.device(x.device.index):
832
+ _chunk_state_bwd_db_kernel[grid_db](
833
+ x, dstates, B, dt, dA_cumsum, seq_idx, dB, ddA_cumsum,
834
+ chunk_size, dstate, headdim,
835
+ batch, seqlen, nheads, nheads_per_program, ngroups,
836
+ x.stride(0), x.stride(1), x.stride(2), x.stride(3),
837
+ dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),
838
+ *B_strides,
839
+ dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
840
+ dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
841
+ *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
842
+ dB.stride(0), dB.stride(1), dB.stride(2), dB.stride(3), dB.stride(4),
843
+ *ddA_cumsum_strides,
844
+ HAS_DDA_CS=ddA_cumsum is not None,
845
+ HAS_SEQ_IDX=seq_idx is not None,
846
+ BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
847
+ )
848
+ dB = dB.sum(2)
849
+ if ddA_cumsum is not None:
850
+ # The first element of ddA_cumsum is always zero, since that dA_cumsum does not contribute
851
+ # to the state of the chunk.
852
+ # torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
853
+ # But it's easier to just do the cumsum for all elements; the result is the same.
854
+ torch.cumsum(ddA_cumsum, dim=-1, out=ddA_cumsum)
855
+ return dB if B is None else (dB, ddA_cumsum)
856
+
857
+
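`_chunk_state_bwd_db` splits the heads of each group across `nsplits` programs so the grid stays large enough to fill the GPU; each split writes its partial gradient into its own slice of dB, and the slices are reduced on the host with the `dB.sum(2)` above. Schematically, with illustrative shapes:

import torch

# dB is allocated as (batch, seqlen, nsplits, ngroups, dstate); each split holds
# the contribution of nheads_per_program heads, and the host sums over splits.
batch, seqlen, nsplits, ngroups, dstate = 2, 64, 3, 1, 16
dB_partial = torch.randn(batch, seqlen, nsplits, ngroups, dstate)
dB = dB_partial.sum(dim=2)  # -> (batch, seqlen, ngroups, dstate)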
858
+ def _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=None):
859
+ batch, seqlen, nheads, headdim = x.shape
860
+ _, _, nchunks, chunk_size = dt.shape
861
+ _, _, ngroups, dstate = B.shape
862
+ assert nheads % ngroups == 0
863
+ assert B.shape == (batch, seqlen, ngroups, dstate)
864
+ assert dt.shape == (batch, nheads, nchunks, chunk_size)
865
+ assert dA_cumsum.shape == dt.shape
866
+ assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
867
+ if seq_idx is not None:
868
+ assert seq_idx.shape == (batch, seqlen)
869
+ # Use torch.empty since the Triton kernel will call init_to_zero
870
+ ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32)
871
+ grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),
872
+ batch * nchunks, nheads)
873
+ with torch.cuda.device(x.device.index):
874
+ _chunk_state_bwd_ddAcs_stable_kernel[grid_ddtcs](
875
+ x, B, dstates, dt, dA_cumsum, seq_idx, ddA_cumsum,
876
+ chunk_size, headdim, dstate,
877
+ batch, seqlen, nheads // ngroups,
878
+ x.stride(0), x.stride(1), x.stride(2), x.stride(3),
879
+ B.stride(0), B.stride(1), B.stride(2), B.stride(-1),
880
+ dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),
881
+ dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
882
+ dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
883
+ *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
884
+ ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3),
885
+ HAS_SEQ_IDX=seq_idx is not None,
886
+ BLOCK_SIZE_M=max(triton.next_power_of_2(chunk_size), 16),
887
+ BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
888
+ )
889
+ torch.cumsum(ddA_cumsum[..., 1:], dim=-1, out=ddA_cumsum[..., 1:])
890
+ return ddA_cumsum
891
+
892
+
893
+ def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states):
894
+ total_seqlen, nheads, headdim = x.shape
895
+ _, nchunks, chunk_size = dt.shape
896
+ _, ngroups, dstate = B.shape
897
+ batch = cu_seqlens.shape[0] - 1
898
+ cu_seqlens = cu_seqlens.contiguous()
899
+ assert nheads % ngroups == 0
900
+ assert B.shape == (total_seqlen, ngroups, dstate)
901
+ assert dt.shape == (nheads, nchunks, chunk_size)
902
+ assert dA_cumsum.shape == dt.shape
903
+ assert chunk_states.shape == (nchunks, nheads, headdim, dstate)
904
+ states = torch.empty(batch, nheads, headdim, dstate, dtype=chunk_states.dtype, device=chunk_states.device)
905
+ grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),
906
+ batch, nheads)
907
+ with torch.cuda.device(x.device.index):
908
+ _chunk_state_varlen_kernel[grid](
909
+ x, B, dt, dA_cumsum, chunk_states, cu_seqlens, states,
910
+ headdim, dstate, chunk_size,
911
+ total_seqlen, nheads // ngroups,
912
+ x.stride(0), x.stride(1), x.stride(2),
913
+ B.stride(0), B.stride(1), B.stride(2),
914
+ dt.stride(1), dt.stride(0), dt.stride(2),
915
+ dA_cumsum.stride(1), dA_cumsum.stride(0), dA_cumsum.stride(2),
916
+ chunk_states.stride(0), chunk_states.stride(1), chunk_states.stride(2), chunk_states.stride(3),
917
+ states.stride(0), states.stride(1), states.stride(2), states.stride(3),
918
+ )
919
+ return states
920
+
921
+
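A hypothetical shape-level sketch of calling `chunk_state_varlen` (the values are made up to satisfy the asserts above; in real use, `chunk_states` would be the per-chunk states produced by the batched forward pass rather than zeros):

import torch

torch.manual_seed(0)
device = "cuda"
total_seqlen, chunk_size = 16, 8
nheads, ngroups, headdim, dstate = 4, 1, 8, 16
nchunks = total_seqlen // chunk_size
cu_seqlens = torch.tensor([0, 5, 14, 16], dtype=torch.int32, device=device)  # 3 packed sequences

x = torch.randn(total_seqlen, nheads, headdim, device=device)
B = torch.randn(total_seqlen, ngroups, dstate, device=device)
dt = torch.rand(nheads, nchunks, chunk_size, device=device)
dA_cumsum = torch.cumsum(-dt, dim=-1)  # stand-in for the real cumulative sums
chunk_states = torch.zeros(nchunks, nheads, headdim, dstate, device=device)

states = chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states)
print(states.shape)  # torch.Size([3, 4, 8, 16]): one final state per packed sequence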
922
+ class ChunkStateFn(torch.autograd.Function):
923
+
924
+ @staticmethod
925
+ def forward(ctx, B, x, dt, dA_cumsum, states_in_fp32=True):
926
+ batch, seqlen, nheads, headdim = x.shape
927
+ _, _, nchunks, chunk_size = dt.shape
928
+ assert seqlen <= nchunks * chunk_size
929
+ _, _, ngroups, dstate = B.shape
930
+ assert B.shape == (batch, seqlen, ngroups, dstate)
931
+ assert dt.shape == (batch, nheads, nchunks, chunk_size)
932
+ assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
933
+ if B.stride(-1) != 1:
934
+ B = B.contiguous()
935
+ if x.stride(-1) != 1 and x.stride(1) != 1: # Either M or K dimension should be contiguous
936
+ x = x.contiguous()
937
+ states = _chunk_state_fwd(B, x, dt, dA_cumsum, states_in_fp32=states_in_fp32)
938
+ ctx.save_for_backward(B, x, dt, dA_cumsum)
939
+ return states
940
+
941
+ @staticmethod
942
+ def backward(ctx, dstates):
943
+ B, x, dt, dA_cumsum = ctx.saved_tensors
944
+ batch, seqlen, nheads, headdim = x.shape
945
+ _, _, nchunks, chunk_size = dt.shape
946
+ _, _, ngroups, dstate = B.shape
947
+ assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
948
+ if dstates.stride(-1) != 1:
949
+ dstates = dstates.contiguous()
950
+ dx, ddt, ddA_cumsum = _chunk_state_bwd_dx(B, x, dt, dA_cumsum, dstates)
951
+ dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, ngroups=ngroups)
952
+ dB = dB.to(B.dtype)
953
+ return dB, dx, ddt, ddA_cumsum, None
954
+
955
+
956
+ def chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True):
957
+ """
958
+ Argument:
959
+ B: (batch, seqlen, ngroups, dstate)
960
+ x: (batch, seqlen, nheads, headdim)
961
+ dt: (batch, nheads, nchunks, chunk_size)
962
+ dA_cumsum: (batch, nheads, nchunks, chunk_size)
963
+ Return:
964
+ states: (batch, nchunks, nheads, headdim, dstate)
965
+ """
966
+ return ChunkStateFn.apply(B, x, dt, dA_cumsum, states_in_fp32)
967
+
968
+
969
+ def chunk_state_ref(B, x, dt, dA_cumsum):
970
+ """
971
+ Argument:
972
+ B: (batch, seqlen, ngroups, headdim)
973
+ x: (batch, seqlen, nheads, headdim)
974
+ dt: (batch, nheads, nchunks, chunk_size)
975
+ dA_cumsum: (batch, nheads, nchunks, chunk_size)
976
+ Return:
977
+ states: (batch, nchunks, nheads, headdim, dstate)
978
+ """
979
+ # Check constraints.
980
+ batch, seqlen, nheads, headdim = x.shape
981
+ dstate = B.shape[-1]
982
+ _, _, nchunks, chunk_size = dt.shape
983
+ assert seqlen <= nchunks * chunk_size
984
+ assert x.shape == (batch, seqlen, nheads, headdim)
985
+ assert dt.shape == (batch, nheads, nchunks, chunk_size)
986
+ ngroups = B.shape[2]
987
+ assert nheads % ngroups == 0
988
+ assert B.shape == (batch, seqlen, ngroups, dstate)
989
+ B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
990
+ assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
991
+ if seqlen < nchunks * chunk_size:
992
+ x = F.pad(x, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
993
+ B = F.pad(B, (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
994
+ x = rearrange(x, "b (c l) h p -> b c l h p", l=chunk_size)
995
+ B = rearrange(B, "b (c l) ... -> b c l ...", l=chunk_size)
996
+ decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum))
997
+ return torch.einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B.to(x.dtype), decay_states.to(x.dtype), dt.to(x.dtype), x)
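Since both the Triton path and the reference live in this file, a quick consistency check between `chunk_state` and `chunk_state_ref` is straightforward; the shapes below are illustrative assumptions, and the tolerance is whatever bf16 accumulation allows:

import torch

torch.manual_seed(0)
device = "cuda"
batch, seqlen, chunk_size = 2, 64, 32
nheads, ngroups, headdim, dstate = 4, 1, 16, 8
nchunks = seqlen // chunk_size

x = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=torch.bfloat16)
B = torch.randn(batch, seqlen, ngroups, dstate, device=device, dtype=torch.bfloat16)
dt = torch.rand(batch, nheads, nchunks, chunk_size, device=device, dtype=torch.float32)
A = -torch.rand(nheads, device=device, dtype=torch.float32)
dA_cumsum = torch.cumsum(dt * A[None, :, None, None], dim=-1)

out_triton = chunk_state(B, x, dt, dA_cumsum)   # fp32 states from the Triton kernel
out_ref = chunk_state_ref(B, x, dt, dA_cumsum)  # bf16 einsum reference
print((out_triton.float() - out_ref.float()).abs().max())  # expected to be small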
build/torch210-cxx11-cu126-x86_64-linux/ops/triton/ssd_combined.py ADDED
@@ -0,0 +1,998 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ """We want triton==2.1.0 or 2.2.0 for this
4
+ """
5
+
6
+ from typing import Optional
7
+
8
+ import math
9
+ from packaging import version
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch import Tensor
14
+ from ...utils.torch import custom_bwd, custom_fwd
15
+
16
+ import triton
17
+ import triton.language as tl
18
+
19
+ from einops import rearrange, repeat
20
+
21
+ try:
22
+ from causal_conv1d import causal_conv1d_fn
23
+ from causal_conv1d.causal_conv1d_interface import causal_conv1d_cuda
24
+ except ImportError:
25
+ causal_conv1d_fn = None
26
+ causal_conv1d_cuda = None
27
+
28
+ from .ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd
29
+ from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd
30
+ from .ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db
31
+ from .ssd_chunk_state import _chunk_state_bwd_ddAcs_stable
32
+ from .ssd_chunk_state import chunk_state, chunk_state_ref
33
+ from .ssd_chunk_state import chunk_state_varlen
34
+ from .ssd_state_passing import _state_passing_fwd, _state_passing_bwd
35
+ from .ssd_state_passing import state_passing, state_passing_ref
36
+ from .ssd_chunk_scan import _chunk_scan_fwd, _chunk_scan_bwd_dz, _chunk_scan_bwd_dstates
37
+ from .ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb
38
+ from .ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable
39
+ from .ssd_chunk_scan import chunk_scan, chunk_scan_ref
40
+ from .ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev
41
+ from .layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd
42
+ from .k_activations import _swiglu_fwd, _swiglu_bwd
43
+
44
+ TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
45
+
46
+
47
+ def init_to_zero(names):
48
+ return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None]
49
+
50
+
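The `pre_hook=init_to_zero([...])` entries in the autotune configs below zero the named output buffers before each candidate config runs: those buffers are accumulated with `tl.atomic_add`, so they must start from zero, which is also why the host-side wrappers allocate them with `torch.empty` instead of `torch.zeros`. A minimal illustration of the hook itself:

import torch

# The autotuner calls the pre_hook with the kernel's bound arguments; the hook
# zeroes the listed tensors in place and leaves everything else untouched.
hook = init_to_zero(["ddt_ptr"])
nargs = {"ddt_ptr": torch.ones(4), "other_ptr": torch.ones(4)}
hook(nargs)
print(nargs["ddt_ptr"])    # tensor([0., 0., 0., 0.])
print(nargs["other_ptr"])  # unchanged: tensor([1., 1., 1., 1.])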
51
+ def rearrange_and_update_stride(tensor, pattern=None, dim=2):
52
+ # Ensure tensor.stride(dim) is a multiple of eight after rearranging according to pattern;
53
+ # if it is not, call contiguous(). The rearrange is applied only if pattern is not None.
54
+ tensor_rearranged = rearrange(tensor, pattern) if pattern is not None else tensor
55
+ return tensor_rearranged.contiguous() if tensor_rearranged.stride(dim) % 8 != 0 else tensor_rearranged
56
+
57
+
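A small illustration of `rearrange_and_update_stride` (shapes chosen only to trigger the copy): a sliced view with an odd stride along `dim=2` gets re-laid-out by `.contiguous()`, while an already well-aligned tensor is returned unchanged.

import torch

base = torch.randn(2, 64, 4, 65)
t = base[..., :64]                    # view: t.stride(2) == 65, not a multiple of 8
t = rearrange_and_update_stride(t)    # copies to a contiguous layout
print(t.stride(2), t.stride(2) % 8 == 0)  # 64 True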
58
+ @triton.autotune(
59
+ configs=[
60
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])),
61
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
62
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
63
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
64
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
65
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
66
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
67
+ triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
68
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
69
+ ],
70
+ key=['chunk_size', 'hdim', 'dstate'],
71
+ )
72
+ @triton.jit
73
+ def _chunk_scan_chunk_state_bwd_dx_kernel(
74
+ # Pointers to matrices
75
+ x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, D_ptr,
76
+ b_ptr, dstates_ptr,
77
+ dx_ptr, ddt_ptr, dD_ptr,
78
+ # Matrix dimensions
79
+ chunk_size, hdim, dstate,
80
+ batch, seqlen, nheads_ngroups_ratio,
81
+ # Strides
82
+ stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
83
+ stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,
84
+ stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
85
+ stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
86
+ stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
87
+ stride_seq_idx_batch, stride_seq_idx_seqlen,
88
+ stride_D_head,
89
+ stride_b_batch, stride_b_seqlen, stride_b_head, stride_b_dstate,
90
+ stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_hdim, stride_dstates_dstate,
91
+ stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim,
92
+ stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize,
93
+ stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim,
94
+ # Meta-parameters
95
+ HAS_D: tl.constexpr,
96
+ D_HAS_HDIM: tl.constexpr,
97
+ HAS_SEQ_IDX: tl.constexpr,
98
+ BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
99
+ BLOCK_SIZE_DSTATE: tl.constexpr,
100
+ IS_TRITON_22: tl.constexpr,
101
+ ):
102
+ pid_bc = tl.program_id(axis=1)
103
+ pid_c = pid_bc // batch
104
+ pid_b = pid_bc - pid_c * batch
105
+ pid_h = tl.program_id(axis=2)
106
+ num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
107
+ pid_m = tl.program_id(axis=0) // num_pid_n
108
+ pid_n = tl.program_id(axis=0) % num_pid_n
109
+ x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
110
+ cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
111
+ dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head
112
+ dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
113
+ ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
114
+ dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
115
+ b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
116
+ dstates_ptr += pid_b * stride_dstates_batch + pid_c * stride_dstates_chunk + pid_h * stride_dstates_head
117
+ if HAS_SEQ_IDX:
118
+ seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
119
+
120
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
121
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
122
+
123
+ chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
124
+
125
+ acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
126
+
127
+ dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
128
+
129
+ dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
130
+ if not HAS_SEQ_IDX:
131
+ # scale = tl.exp(dA_cs_last - dA_cs_m)
132
+ scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0))
133
+ else:
134
+ seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
135
+ seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)
136
+ # scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(dA_cs_last - dA_cs_m), 0.0)
137
+ scale = tl.where(seq_idx_m == seq_idx_last, tl.exp(tl.minimum((dA_cs_last - dA_cs_m), 0.0)), 0.0)
138
+ # Might be faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
139
+ # However, we're getting an error with the Triton compiler 2.1.0 for that code path:
140
+ # Unexpected mma -> mma layout conversion
141
+ # Triton 2.2.0 fixes this
142
+ offs_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
143
+ b_ptrs = b_ptr + (offs_m[:, None] * stride_b_seqlen + offs_dstate[None, :] * stride_b_dstate)
144
+ dstates_ptrs = dstates_ptr + (offs_n[None, :] * stride_dstates_hdim + offs_dstate[:, None] * stride_dstates_dstate)
145
+ if IS_TRITON_22 and BLOCK_SIZE_DSTATE <= 128:
146
+ b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate), other=0.0)
147
+ dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)
148
+ dstates = dstates.to(b_ptr.dtype.element_ty)
149
+ acc = tl.dot(b, dstates) * scale[:, None]
150
+ else:
151
+ for k in range(0, dstate, BLOCK_SIZE_K):
152
+ b = tl.load(b_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_dstate[None, :] < dstate - k), other=0.0)
153
+ dstates = tl.load(dstates_ptrs, mask=(offs_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)
154
+ dstates = dstates.to(b_ptr.dtype.element_ty)
155
+ acc += tl.dot(b, dstates)
156
+ b_ptrs += BLOCK_SIZE_K * stride_b_dstate
157
+ dstates_ptrs += BLOCK_SIZE_K * stride_dstates_dstate
158
+ acc *= scale[:, None]
159
+
160
+ # x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
161
+ # x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
162
+ # dt_ptrs = dt_ptr + offs_m * stride_dt_csize
163
+ # dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
164
+ # ddt = tl.sum(acc * x, axis=1) * dt_m
165
+ # ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
166
+ # tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
167
+
168
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
169
+ cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k)
170
+ dout_ptrs = dout_ptr + (offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)
171
+ dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
172
+ K_MAX = chunk_size_limit
173
+ K_MIN = pid_m * BLOCK_SIZE_M
174
+ cb_ptrs += K_MIN * stride_cb_csize_k
175
+ dout_ptrs += K_MIN * stride_dout_seqlen
176
+ dA_cumsum_ptrs += K_MIN * stride_dA_cs_csize
177
+ for k in range(K_MIN, K_MAX, BLOCK_SIZE_K):
178
+ k = tl.multiple_of(k, BLOCK_SIZE_K)
179
+ # For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower
180
+ cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0)
181
+ dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0)
182
+ dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)
183
+ # cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
184
+ cb *= tl.exp(tl.minimum((dA_cs_k[None, :] - dA_cs_m[:, None]), 0.0))
185
+ # If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range,
186
+ # we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf.
187
+ # Multiplying with cb, which is 0.0 outside the range, will make the result NaN.
188
+ # This will cause NaN in acc, and hence NaN in dx and ddt.
189
+ mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX)
190
+ cb = tl.where(mask, cb, 0.0)
191
+ cb = cb.to(dout_ptr.dtype.element_ty)
192
+ acc += tl.dot(cb, dout)
193
+ cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
194
+ dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen
195
+ dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
196
+
197
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
198
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
199
+ dt_ptrs = dt_ptr + offs_m * stride_dt_csize
200
+ dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
201
+ dx = acc * dt_m[:, None]
202
+ dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head
203
+ dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim)
204
+ if HAS_D:
205
+ dout_res_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)
206
+ dout_res = tl.load(dout_res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
207
+ if D_HAS_HDIM:
208
+ D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)
209
+ else:
210
+ D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
211
+ dx += dout_res * D
212
+ tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))
213
+
214
+ x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
215
+ x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
216
+ if HAS_D:
217
+ dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize
218
+ if D_HAS_HDIM:
219
+ dD_ptrs = dD_ptr + offs_n * stride_dD_hdim
220
+ dD = tl.sum(dout_res * x, axis=0)
221
+ tl.store(dD_ptrs, dD, mask=offs_n < hdim)
222
+ else:
223
+ dD = tl.sum(dout_res * x)
224
+ tl.store(dD_ptr, dD)
225
+ ddt = tl.sum(acc * x, axis=1)
226
+ ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
227
+ tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)
228
+
229
+
230
+ def _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=None, seq_idx=None, dx=None):
231
+ batch, seqlen, nheads, headdim = x.shape
232
+ _, _, nchunks, chunk_size = dt.shape
233
+ _, _, ngroups, dstate = B.shape
234
+ assert nheads % ngroups == 0
235
+ assert B.shape == (batch, seqlen, ngroups, dstate)
236
+ assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
237
+ assert dt.shape == (batch, nheads, nchunks, chunk_size)
238
+ assert dA_cumsum.shape == dt.shape
239
+ assert dout.shape == x.shape
240
+ assert dstates.shape == (batch, nchunks, nheads, headdim, dstate)
241
+ if seq_idx is not None:
242
+ assert seq_idx.shape == (batch, seqlen)
243
+ if D is not None:
244
+ assert D.shape == (nheads, headdim) or D.shape == (nheads,)
245
+ assert D.stride(-1) == 1
246
+ BLOCK_SIZE_min = 32
247
+ dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads,
248
+ headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32)
249
+ else:
250
+ dD = None
251
+ dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4))
252
+ if D is not None else (0, 0, 0, 0, 0))
253
+ if dx is None:
254
+ dx = torch.empty_like(x)
255
+ else:
256
+ assert dx.shape == x.shape
257
+ ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32)
258
+ grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),
259
+ batch * nchunks, nheads)
260
+ with torch.cuda.device(x.device.index):
261
+ _chunk_scan_chunk_state_bwd_dx_kernel[grid_dx](
262
+ x, CB, dout, dt, dA_cumsum, seq_idx, D, B, dstates, dx, ddt, dD,
263
+ chunk_size, headdim, dstate,
264
+ batch, seqlen, nheads // ngroups,
265
+ x.stride(0), x.stride(1), x.stride(2), x.stride(3),
266
+ CB.stride(0), CB.stride(1), CB.stride(2), CB.stride(-1), CB.stride(-2),
267
+ dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
268
+ dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
269
+ dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
270
+ *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
271
+ D.stride(0) if D is not None else 0,
272
+ B.stride(0), B.stride(1), B.stride(2), B.stride(3),
273
+ dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3), dstates.stride(4),
274
+ dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3),
275
+ ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3),
276
+ dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4],
277
+ D is not None,
278
+ D.dim() == 2 if D is not None else True,
279
+ HAS_SEQ_IDX=seq_idx is not None,
280
+ BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
281
+ IS_TRITON_22=TRITON_22
282
+ )
283
+ if D is not None:
284
+ BLOCK_SIZE_actual = _chunk_scan_chunk_state_bwd_dx_kernel.best_config.kwargs["BLOCK_SIZE_M"]
285
+ n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
286
+ dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype)
287
+ if D.dim() == 1:
288
+ dD = rearrange(dD, "h 1 -> h")
289
+ return dx, ddt.to(dtype=dt.dtype), dD
290
+
291
+
292
+ def _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf"))):
293
+ batch, seqlen, nheads, headdim = x.shape
294
+ _, _, ngroups, dstate = B.shape
295
+ assert nheads % ngroups == 0
296
+ assert B.shape == (batch, seqlen, ngroups, dstate)
297
+ assert x.shape == (batch, seqlen, nheads, headdim)
298
+ assert dt.shape == (batch, seqlen, nheads)
299
+ assert A.shape == (nheads,)
300
+ assert C.shape == B.shape
301
+ if z is not None:
302
+ assert z.shape == x.shape
303
+ if D is not None:
304
+ assert D.shape == (nheads, headdim) or D.shape == (nheads,)
305
+ if seq_idx is not None:
306
+ assert seq_idx.shape == (batch, seqlen)
307
+ if B.stride(-1) != 1:
308
+ B = B.contiguous()
309
+ if C.stride(-1) != 1:
310
+ C = C.contiguous()
311
+ if x.stride(-1) != 1 and x.stride(1) != 1: # Either M or K dimension should be contiguous
312
+ x = x.contiguous()
313
+ if z is not None and z.stride(-1) != 1 and z.stride(1) != 1: # Either M or K dimension should be contiguous
314
+ z = z.contiguous()
315
+ if D is not None and D.stride(-1) != 1:
316
+ D = D.contiguous()
317
+ if initial_states is not None:
318
+ assert initial_states.shape == (batch, nheads, headdim, dstate)
319
+ # # (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, nheads, chunk_size, chunk_size)
320
+ # dA_cumsum_tmp0, dt_tmp0 = _chunk_cumsum_fwd(dt[:, :147], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
321
+ # dA_cumsum_tmp1, dt_tmp1 = _chunk_cumsum_fwd(dt[:, 147:], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
322
+ # dA_cumsum_tmp2, dt_tmp2 = _chunk_cumsum_fwd(dt[:, 147:256], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
323
+ dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit)
324
+ states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
325
+ # states_tmp0 = _chunk_state_fwd(B[:, :147], x[:, :147], dt_tmp0, dA_cumsum_tmp0, states_in_fp32=True)
326
+ # states_tmp1 = _chunk_state_fwd(B[:, 147:], x[:, 147:], dt_tmp1, dA_cumsum_tmp1, states_in_fp32=True)
327
+ # states_tmp2 = _chunk_state_fwd(B[:, 147:256], x[:, 147:256], dt_tmp2, dA_cumsum_tmp2, states_in_fp32=True)
328
+ states, final_states = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1],
329
+ initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None,
330
+ seq_idx=seq_idx, chunk_size=chunk_size, out_dtype=C.dtype)
331
+ states, final_states = [rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]]
332
+ # states_tmp0 = rearrange(_state_passing_fwd(rearrange(states_tmp0, "... p n -> ... (p n)"), dA_cumsum_tmp0[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
333
+ # states_tmp1 = rearrange(_state_passing_fwd(rearrange(states_tmp1, "... p n -> ... (p n)"), dA_cumsum_tmp1[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
334
+ CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
335
+ out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, states, D=D, z=z, seq_idx=seq_idx)
336
+ if cu_seqlens is None:
337
+ return out, out_x, dt, dA_cumsum, states, final_states
338
+ else:
339
+ assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1"
340
+ varlen_states = chunk_state_varlen(B.squeeze(0), x.squeeze(0), dt.squeeze(0), dA_cumsum.squeeze(0),
341
+ cu_seqlens, states.squeeze(0))
342
+ return out, out_x, dt, dA_cumsum, states, final_states, varlen_states
343
+
344
+
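For orientation, a shapes-only sketch of the forward pipeline defined above; the intermediate shapes are inferred from the asserts and rearranges in the function, so treat this as an informal summary rather than a specification.

# dA_cumsum, dt        : (batch, nheads, nchunks, chunk_size)              from _chunk_cumsum_fwd
# states               : (batch, nchunks, nheads, headdim, dstate)         from _chunk_state_fwd
# states, final_states : carried chunk-to-chunk by _state_passing_fwd
# CB                   : (batch, nchunks, ngroups, chunk_size, chunk_size) from _bmm_chunk_fwd
# out, out_x           : (batch, seqlen, nheads, headdim)                  from _chunk_scan_fwd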
345
+ def _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, chunk_size, D=None, z=None,
346
+ dt_bias=None, initial_states=None, dfinal_states=None, seq_idx=None, dt_softplus=False,
347
+ dt_limit=(0.0, float("inf")),
348
+ dx=None, ddt=None, dB=None, dC=None, dz=None, recompute_output=False):
349
+ if dout.stride(-1) != 1:
350
+ dout = dout.contiguous()
351
+ batch, seqlen, nheads, headdim = x.shape
352
+ nchunks = math.ceil(seqlen / chunk_size)
353
+ _, _, ngroups, dstate = B.shape
354
+ assert dout.shape == (batch, seqlen, nheads, headdim)
355
+ assert dt.shape == (batch, seqlen, nheads)
356
+ assert A.shape == (nheads,)
357
+ assert nheads % ngroups == 0
358
+ assert B.shape == (batch, seqlen, ngroups, dstate)
359
+ assert C.shape == B.shape
360
+ assert out.shape == x.shape
361
+ if initial_states is not None:
362
+ assert initial_states.shape == (batch, nheads, headdim, dstate)
363
+ if seq_idx is not None:
364
+ assert seq_idx.shape == (batch, seqlen)
365
+ if dx is not None:
366
+ assert dx.shape == x.shape
367
+ if dB is not None:
368
+ assert dB.shape == B.shape
369
+ dB_given = dB
370
+ else:
371
+ dB_given = torch.empty_like(B)
372
+ if dC is not None:
373
+ assert dC.shape == C.shape
374
+ dC_given = dC
375
+ else:
376
+ dC_given = torch.empty_like(C)
377
+ if dz is not None:
378
+ assert z is not None
379
+ assert dz.shape == z.shape
380
+ if ddt is not None:
381
+ assert ddt.shape == dt.shape
382
+ ddt_given = ddt
383
+ else:
384
+ ddt_given = torch.empty_like(dt)
385
+ # TD: For some reason Triton (2.1.0 and 2.2.0) errors with
386
+ # "[CUDA]: invalid device context" (e.g. during varlne test), and cloning makes it work. Idk why.
387
+ dt_in = dt.clone()
388
+ dA_cumsum, dt = _chunk_cumsum_fwd(dt_in, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus,
389
+ dt_limit=dt_limit)
390
+ CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
391
+ states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
392
+ states, _ = _state_passing_fwd(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1],
393
+ initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None,
394
+ seq_idx=seq_idx, chunk_size=chunk_size)
395
+ states = rearrange(states, "... (p n) -> ... p n", n=dstate)
396
+ if z is not None:
397
+ dz, dout, dD, *rest = _chunk_scan_bwd_dz(x, z, out, dout, chunk_size=chunk_size, has_ddAcs=False, D=D, dz=dz, recompute_output=recompute_output)
398
+ outz = rest[0] if recompute_output else out
399
+ else:
400
+ dz = None
401
+ outz = out
402
+ dstates = _chunk_scan_bwd_dstates(C, dA_cumsum, dout, seq_idx=seq_idx, dtype=states.dtype)
403
+ # dstates has length nchunks, containing the gradient to initial states at index 0 and
404
+ # gradient to the states of chunk (nchunks - 2) at index (nchunks - 1)
405
+ # Do computation in fp32 but convert dstates and states to fp16/bf16 since dstates and states
406
+ # will be used in matmul in the next kernels.
407
+ dstates, ddA_chunk_cumsum, dinitial_states, states = _state_passing_bwd(
408
+ rearrange(states, "... p n -> ... (p n)"),
409
+ dA_cumsum[:, :, :, -1],
410
+ rearrange(dstates, "... p n -> ... (p n)"),
411
+ dfinal_states=rearrange(dfinal_states, "... p n -> ... (p n)") if dfinal_states is not None else None,
412
+ seq_idx=seq_idx,
413
+ has_initial_states=initial_states is not None,
414
+ dstates_dtype=x.dtype,
415
+ states_dtype=x.dtype,
416
+ chunk_size=chunk_size,
417
+ )
418
+ # dstates has length nchunks, containing the gradient to states of chunk 0 at index 0 and
419
+ # gradient to the final states at index (nchunks - 1)
420
+ # states has length nchunks, containing the initial states at index 0 and the state for chunk (nchunks - 2) at index (nchunks - 1)
421
+ # The final states is not stored.
422
+ states = rearrange(states, "... (p n) -> ... p n", n=dstate)
423
+ dstates = rearrange(dstates, "... (p n) -> ... p n", n=dstate)
424
+ dinitial_states = rearrange(dinitial_states, "... (p n) -> ... p n", n=dstate) if dinitial_states is not None else None
425
+ dx, ddt, dD_from_x = _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=D, seq_idx=seq_idx, dx=dx)
426
+ # dB = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, ngroups=ngroups)
427
+ dB, ddA_next = _chunk_state_bwd_db(x, dt, dA_cumsum, dstates, seq_idx=seq_idx, B=B, ngroups=ngroups)
428
+ # dC = _chunk_scan_bwd_dC(states[:, :-1].to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)
429
+ dC, ddA_cumsum_prev = _chunk_scan_bwd_dC(states.to(x.dtype), dA_cumsum, dout, seq_idx=seq_idx, C=C, ngroups=ngroups)
430
+ # Computing ddA with the dcb kernel is much slower, so we're not using it for now
431
+ dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, ngroups=ngroups)
432
+ # dCB, ddA_tmp = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=seq_idx, CB=CB, ngroups=ngroups)
433
+ dCB = dCB.to(CB.dtype)
434
+ _bmm_chunk_bwd(C, dCB, residual=dB, out=dB_given)
435
+ _bmm_chunk_bwd(B, rearrange(dCB, "... l s -> ... s l"), residual=dC, out=dC_given)
436
+ # If we have z, then dout_x is recomputed in fp32 so dD = (dout_x * x).sum() is more accurate
437
+ # than dD_from_x = (dout_x * x).sum() where dout_x is in fp16/bf16
438
+ if z is None:
439
+ dD = dD_from_x
440
+ # Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D.
441
+ # ddA_cumsum = torch.einsum("bclhp,bclhp->bhcl", out.float(), dout.float()) - ddt * dt
442
+ # However, this is numerically unstable: when we do the reverse cumsum on ddA_cumsum, there might
443
+ # be a lot of underflow.
444
+
445
+ # This is already done as part of bwd_dC kernel
446
+ # ddA_cumsum_prev = _chunk_scan_bwd_ddAcs_prev(states[:, :-1], C, dout, dA_cumsum, seq_idx=seq_idx)
447
+ ddA_cumsum_prev[..., -1] += ddA_chunk_cumsum
448
+ ddA_prev = ddA_cumsum_prev.flip([-1]).cumsum(dim=-1).flip([-1])
449
+ # This is already done as part of bwd_dB kernel
450
+ # ddA_next = _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=seq_idx)
451
+ # We don't need to pass in seq_idx because CB also zeros out entries where seq_idx[i] != seq_idx[j]
452
+ ddA = _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, CB)
453
+ ddA += ddA_next + ddA_prev
454
+
455
+ ddt_given, dA, ddt_bias = _chunk_cumsum_bwd(ddA, ddt, dt_in, A, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit, ddt=ddt_given)
456
+
457
+ # These 2 lines are just to test ddt and dA being computed by old code
458
+ # _, dA = selective_scan_bwd(dout, x, dt, A, B, C, D=D.float(), z=z)
459
+ # ddt_given.copy_(ddt)
460
+
461
+ return_vals = (dx, ddt_given, dA, dB_given, dC_given, dD, dz, ddt_bias, dinitial_states)
462
+ return return_vals if not recompute_output else (*return_vals, outz)
463
+
464
+
465
+ def selective_scan_bwd(dout, x, dt, A, B, C, D=None, z=None):
466
+ """
467
+ Argument:
468
+ dout: (batch, seqlen, nheads, headdim)
469
+ x: (batch, seqlen, nheads, headdim)
470
+ dt: (batch, nheads, nchunks, chunk_size) or (batch, nheads, headdim, nchunks, chunk_size)
471
+ A: (nheads) or (dim, dstate)
472
+ B: (batch, seqlen, ngroups, dstate)
473
+ C: (batch, seqlen, ngroups, dstate)
474
+ D: (nheads, headdim) or (nheads,)
475
+ z: (batch, seqlen, nheads, headdim)
476
+ Return:
477
+ out: (batch, seqlen, nheads, headdim)
478
+ """
479
+ import selective_scan
480
+
481
+ batch, seqlen, nheads, headdim = x.shape
482
+ chunk_size = dt.shape[-1]
483
+ _, _, ngroups, dstate = B.shape
484
+ assert nheads % ngroups == 0
485
+ x = rearrange(x, "b l h p -> b (h p) l")
486
+ squeeze_dt = dt.dim() == 4
487
+ if dt.dim() == 4:
488
+ dt = repeat(dt, "b h c l -> b h p c l", p=headdim)
489
+ dt = rearrange(dt, "b h p c l -> b (h p) (c l)", p=headdim)
490
+ squeeze_A = A.dim() == 1
491
+ if A.dim() == 1:
492
+ A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32)
493
+ else:
494
+ A = A.to(dtype=torch.float32)
495
+ B = rearrange(B, "b l g n -> b g n l")
496
+ C = rearrange(C, "b l g n -> b g n l")
497
+ if D is not None:
498
+ if D.dim() == 2:
499
+ D = rearrange(D, "h p -> (h p)")
500
+ else:
501
+ D = repeat(D, "h -> (h p)", p=headdim)
502
+ if z is not None:
503
+ z = rearrange(z, "b l h p -> b (h p) l")
504
+
505
+ if x.stride(-1) != 1:
506
+ x = x.contiguous()
507
+ if dt.stride(-1) != 1:
508
+ dt = dt.contiguous()
509
+ if D is not None:
510
+ D = D.contiguous()
511
+ if B.stride(-1) != 1:
512
+ B = B.contiguous()
513
+ if C.stride(-1) != 1:
514
+ C = C.contiguous()
515
+ if z is not None and z.stride(-1) != 1:
516
+ z = z.contiguous()
517
+ _, intermediate, *rest = selective_scan.fwd(x, dt.to(dtype=x.dtype), A, B, C, D, z, None, False)
518
+ if z is not None:
519
+ out = rest[0]
520
+ else:
521
+ out = None
522
+
523
+ dout = rearrange(dout, "b l h p -> b (h p) l")
524
+
525
+ if dout.stride(-1) != 1:
526
+ dout = dout.contiguous()
527
+ # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
528
+ # backward of selective_scan with the backward of chunk).
529
+ # Here we just pass in None and dz will be allocated in the C++ code.
530
+ _, ddt, dA, *rest = selective_scan.bwd(
531
+ x, dt.to(dtype=x.dtype), A, B, C, D, z, None, dout, intermediate, out, None, False,
532
+ False # option to recompute out_z, not used here
533
+ )
534
+ ddt = rearrange(ddt, "b (h p) (c l) -> b h p c l", p=headdim, l=chunk_size)
535
+ if squeeze_dt:
536
+ ddt = ddt.float().sum(dim=2)
537
+ if squeeze_A:
538
+ dA = rearrange(dA, "(h p) n -> h p n", p=headdim).sum(dim=(1, 2))
539
+ return ddt, dA
540
+
541
+
542
+ class MambaChunkScanCombinedFn(torch.autograd.Function):
543
+
544
+ @staticmethod
545
+ def forward(ctx, x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False, return_varlen_states=False):
546
+ ctx.dt_dtype = dt.dtype
547
+ if not return_varlen_states:
548
+ cu_seqlens = None
549
+ else:
550
+ assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True"
551
+ out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, cu_seqlens=cu_seqlens, dt_softplus=dt_softplus, dt_limit=dt_limit)
552
+ ctx.save_for_backward(out if z is None else out_x, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx)
553
+ ctx.dt_softplus = dt_softplus
554
+ ctx.chunk_size = chunk_size
555
+ ctx.dt_limit = dt_limit
556
+ ctx.return_final_states = return_final_states
557
+ ctx.return_varlen_states = return_varlen_states
558
+ if not return_varlen_states:
559
+ return out if not return_final_states else (out, final_states)
560
+ else:
561
+ varlen_states = rest[0]
562
+ return (out, varlen_states) if not return_final_states else (out, final_states, varlen_states)
563
+
564
+ @staticmethod
565
+ def backward(ctx, dout, *args):
566
+ out, x, dt, dA_cumsum, A, B, C, D, z, dt_bias, initial_states, seq_idx = ctx.saved_tensors
567
+ assert not ctx.return_varlen_states, "return_varlen_states is not supported in backward"
568
+ dfinal_states = args[0] if ctx.return_final_states else None
569
+ dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=ctx.dt_softplus, dt_limit=ctx.dt_limit)
570
+ return dx, ddt, dA, dB, dC, None, dD, dz, ddt_bias, dinitial_states, None, None, None, None, None, None
571
+
572
+
573
+ def mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False, return_varlen_states=False):
574
+ """
575
+ Argument:
576
+ x: (batch, seqlen, nheads, headdim)
577
+ dt: (batch, seqlen, nheads)
578
+ A: (nheads)
579
+ B: (batch, seqlen, ngroups, dstate)
580
+ C: (batch, seqlen, ngroups, dstate)
581
+ chunk_size: int
582
+ D: (nheads, headdim) or (nheads,)
583
+ z: (batch, seqlen, nheads, headdim)
584
+ dt_bias: (nheads,)
585
+ initial_states: (batch, nheads, headdim, dstate)
586
+ seq_idx: (batch, seqlen)
587
+ cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
588
+ dt_softplus: Whether to apply softplus to dt
589
+ Return:
590
+ out: (batch, seqlen, nheads, headdim)
591
+ """
592
+ return MambaChunkScanCombinedFn.apply(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, cu_seqlens, dt_softplus, dt_limit, return_final_states, return_varlen_states)
593
+
594
+
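A minimal usage sketch for the public entry point above. The sizes are hypothetical, and a CUDA device with the Triton kernels from this file is assumed:

import torch

batch, seqlen, nheads, headdim = 2, 512, 8, 64
ngroups, dstate, chunk_size = 1, 16, 64
x = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
dt = torch.rand(batch, seqlen, nheads, device="cuda", dtype=torch.bfloat16)
A = -torch.rand(nheads, device="cuda")          # negative real decay rates
B = torch.randn(batch, seqlen, ngroups, dstate, device="cuda", dtype=torch.bfloat16)
C = torch.randn(batch, seqlen, ngroups, dstate, device="cuda", dtype=torch.bfloat16)
out = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, dt_softplus=True)
assert out.shape == (batch, seqlen, nheads, headdim)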
595
+ def mamba_chunk_scan(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False):
596
+ """
597
+ Argument:
598
+ x: (batch, seqlen, nheads, headdim)
599
+ dt: (batch, seqlen, nheads)
600
+ A: (nheads)
601
+ B: (batch, seqlen, ngroups, dstate)
602
+ C: (batch, seqlen, ngroups, dstate)
603
+ D: (nheads, headdim) or (nheads,)
604
+ z: (batch, seqlen, nheads, headdim)
605
+ dt_bias: (nheads,)
606
+ Return:
607
+ out: (batch, seqlen, nheads, headdim)
608
+ """
609
+ batch, seqlen, nheads, headdim = x.shape
610
+ dstate = B.shape[-1]
611
+ if seqlen % chunk_size != 0:
612
+ dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))
613
+ dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
614
+ dt = dt.float() # We want high precision for this before cumsum
615
+ if dt_bias is not None:
616
+ dt = dt + rearrange(dt_bias, "h -> h 1 1")
617
+ if dt_softplus:
618
+ dt = F.softplus(dt)
619
+ dA = dt * rearrange(A, "h -> h 1 1")
621
+ dA_cumsum = torch.cumsum(dA, dim=-1)
622
+ # 1. Compute the state for each chunk
623
+ states = chunk_state(B, x, dt, dA_cumsum, states_in_fp32=True)
624
+ # 2. Pass the state to all the chunks by weighted cumsum.
625
+ states = rearrange(state_passing(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1])[0],
626
+ "... (p n) -> ... p n", n=dstate)
627
+ # 3. Compute the output for each chunk
628
+ out = chunk_scan(B, C, x, dt, dA_cumsum, states, D=D, z=z)
629
+ return out
630
+
631
+
632
+ def ssd_chunk_scan_combined_ref(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, dt_softplus=False):
633
+ """
634
+ Argument:
635
+ x: (batch, seqlen, nheads, headdim)
636
+ dt: (batch, seqlen, nheads)
637
+ A: (nheads)
638
+ B: (batch, seqlen, ngroups, dstate)
639
+ C: (batch, seqlen, ngroups, dstate)
640
+ D: (nheads, headdim) or (nheads,)
641
+ z: (batch, seqlen, nheads, headdim)
642
+ dt_bias: (nheads,)
643
+ Return:
644
+ out: (batch, seqlen, nheads, headdim)
645
+ """
646
+ batch, seqlen, nheads, headdim = x.shape
647
+ dstate = B.shape[-1]
648
+ if seqlen % chunk_size != 0:
649
+ dt = F.pad(dt, (0, 0, 0, chunk_size - seqlen % chunk_size))
650
+ dt = rearrange(dt, "b (c l) h -> b h c l", l=chunk_size)
651
+ dt = dt.float() # We want high precision for this before cumsum
652
+ if dt_bias is not None:
653
+ dt = dt + rearrange(dt_bias, "h -> h 1 1")
654
+ if dt_softplus:
655
+ dt = F.softplus(dt)
656
+ dA = dt * rearrange(A, "h -> h 1 1")
657
+ dA_cumsum = torch.cumsum(dA, dim=-1)
658
+ # 1. Compute the state for each chunk
659
+ states = chunk_state_ref(B, x, dt, dA_cumsum)
660
+ states_dtype = states.dtype
661
+ if states.dtype not in [torch.float32, torch.float64]:
662
+ states = states.to(torch.float32)
663
+ # 2. Pass the state to all the chunks by weighted cumsum.
664
+ # state_passing_ref is much less numerically stable
665
+ states = rearrange(state_passing_ref(rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1])[0],
666
+ "... (p n) -> ... p n", n=dstate)
667
+ states = states.to(states_dtype)
668
+ # 3. Compute the output for each chunk
669
+ out = chunk_scan_ref(B, C, x, dt, dA_cumsum, states, D=D, z=z)
670
+ return out
671
+
672
+
673
+ def ssd_selective_scan(x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))):
674
+ """
675
+ Argument:
676
+ x: (batch, seqlen, nheads, headdim)
677
+ dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim)
678
+ A: (nheads) or (dim, dstate)
679
+ B: (batch, seqlen, ngroups, dstate)
680
+ C: (batch, seqlen, ngroups, dstate)
681
+ D: (nheads, headdim) or (nheads,)
682
+ z: (batch, seqlen, nheads, headdim)
683
+ dt_bias: (nheads,) or (nheads, headdim)
684
+ Return:
685
+ out: (batch, seqlen, nheads, headdim)
686
+ """
687
+ from ..selective_scan_interface import selective_scan_fn
688
+
689
+ batch, seqlen, nheads, headdim = x.shape
690
+ _, _, ngroups, dstate = B.shape
691
+ x = rearrange(x, "b l h p -> b (h p) l")
692
+ if dt.dim() == 3:
693
+ dt = repeat(dt, "b l h -> b l h p", p=headdim)
694
+ dt = rearrange(dt, "b l h p -> b (h p) l")
695
+ if A.dim() == 1:
696
+ A = repeat(A, "h -> (h p) n", p=headdim, n=dstate).to(dtype=torch.float32)
697
+ else:
698
+ A = A.to(dtype=torch.float32)
699
+ B = rearrange(B, "b l g n -> b g n l")
700
+ C = rearrange(C, "b l g n -> b g n l")
701
+ if D is not None:
702
+ if D.dim() == 2:
703
+ D = rearrange(D, "h p -> (h p)")
704
+ else:
705
+ D = repeat(D, "h -> (h p)", p=headdim)
706
+ if z is not None:
707
+ z = rearrange(z, "b l h p -> b (h p) l")
708
+ if dt_bias is not None:
709
+ if dt_bias.dim() == 1:
710
+ dt_bias = repeat(dt_bias, "h -> h p", p=headdim)
711
+ dt_bias = rearrange(dt_bias, "h p -> (h p)")
712
+ if dt_limit != (0.0, float("inf")):
713
+ if dt_bias is not None:
714
+ dt = dt + rearrange(dt_bias, "d -> d 1")
715
+ if dt_softplus:
716
+ dt = F.softplus(dt)
717
+ dt = dt.clamp(min=dt_limit[0], max=dt_limit[1]).to(x.dtype)
718
+ dt_bias = None
719
+ dt_softplus = None
720
+ out = selective_scan_fn(x, dt, A, B, C, D=D, z=z, delta_bias=dt_bias, delta_softplus=dt_softplus)
721
+ return rearrange(out, "b (h p) l -> b l h p", p=headdim)
722
+
723
+
724
+ def mamba_conv1d_scan_ref(xBC, conv1d_weight, conv1d_bias, dt, A, chunk_size, D=None, z=None,
725
+ dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf")),
726
+ activation="silu", headdim=None, ngroups=1):
727
+ """
728
+ Argument:
729
+ xBC: (batch, seqlen, dim + 2 * ngroups * dstate) where dim == nheads * headdim
730
+ conv1d_weight: (dim + 2 * ngroups * dstate, width)
731
+ conv1d_bias: (dim + 2 * ngroups * dstate,)
732
+ dt: (batch, seqlen, nheads) or (batch, seqlen, nheads, headdim)
733
+ A: (nheads)
734
+ D: (nheads, headdim) or (nheads,)
735
+ z: (batch, seqlen, dim)
736
+ dt_bias: (nheads) or (nheads, headdim)
737
+ headdim: if D is 1D and z is None, headdim must be passed in
738
+ Return:
739
+ out: (batch, seqlen, dim)
740
+ """
741
+ batch, seqlen, nheads = dt.shape[:3]
742
+ assert nheads % ngroups == 0
743
+ if z is not None:
744
+ dim = z.shape[-1]
745
+ assert dim % nheads == 0
746
+ headdim = dim // nheads
747
+ else:
748
+ if D.dim() == 1:
749
+ assert headdim is not None
750
+ else:
751
+ headdim = D.shape[1]
752
+ dim = nheads * headdim
753
+ xBC = rearrange(causal_conv1d_fn(rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias, activation=activation),
754
+ "b d s -> b s d")
755
+ dstate = (xBC.shape[-1] - dim) // ngroups // 2
756
+ x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
757
+ x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
758
+ B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
759
+ C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
760
+ z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
761
+ out = ssd_selective_scan(x, dt.to(x.dtype), A, B, C, D=D.float(), z=z, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit)
762
+ return rearrange(out, "b s h p -> b s (h p)")
763
+
764
+
765
+ class MambaSplitConv1dScanCombinedFn(torch.autograd.Function):
766
+
767
+ @staticmethod
768
+ @custom_fwd
769
+ def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu",
770
+ rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None,
771
+ ngroups=1, norm_before_gate=True):
772
+ assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
773
+ assert activation in [None, "silu", "swish"]
774
+ if D.dim() == 1:
775
+ assert headdim is not None
776
+ nheads, = D.shape
777
+ else:
778
+ nheads, headdim = D.shape
779
+ batch, seqlen, _ = zxbcdt.shape
780
+ dim = nheads * headdim
781
+ assert nheads % ngroups == 0
782
+ dstate = (conv1d_weight.shape[0] - dim) // ngroups // 2
783
+ d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ngroups * dstate - nheads) // 2
784
+ assert d_nonssm >= 0
785
+ assert zxbcdt.shape == (batch, seqlen, 2 * d_nonssm + 2 * dim + 2 * ngroups * dstate + nheads)
786
+ assert dt_bias.shape == (nheads,)
787
+ assert A.shape == (nheads,)
788
+ zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + ngroups * dstate * 2, nheads], dim=-1)
789
+ seq_idx = seq_idx.contiguous() if seq_idx is not None else None
790
+ xBC_conv = rearrange(
791
+ causal_conv1d_cuda.causal_conv1d_fwd(rearrange_and_update_stride(xBC, "b s d -> b d s"),
792
+ conv1d_weight, conv1d_bias, seq_idx, None, None, activation in ["silu", "swish"]),
793
+ "b d s -> b s d"
794
+ )
795
+ x, B, C = torch.split(xBC_conv, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
796
+ x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
797
+ B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
798
+ C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
799
+ z = rearrange(z, "b l (h p) -> b l h p", h=nheads) if z is not None else None
800
+ if rmsnorm_weight is None:
801
+ out, out_x, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit)
802
+ out = rearrange(out, "b s h p -> b s (h p)")
803
+ rstd = None
804
+ if d_nonssm > 0:
805
+ out = torch.cat([_swiglu_fwd(zx0), out], dim=-1)
806
+ else:
807
+ out_x, _, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit)
808
+ # reshape input data into 2D tensor
809
+ x_rms = rearrange(out_x, "b s h p -> (b s) (h p)")
810
+ z_rms = rearrange(z, "b s h p -> (b s) (h p)")
811
+ rmsnorm_weight = rmsnorm_weight.contiguous()
812
+ if d_nonssm == 0:
813
+ out = None
814
+ else:
815
+ out01 = torch.empty((batch, seqlen, d_nonssm + dim), dtype=x_rms.dtype, device=x_rms.device)
816
+ out = rearrange(out01[..., d_nonssm:], "b s d -> (b s) d")
817
+ _swiglu_fwd(zx0, out=out01[..., :d_nonssm])
818
+ out, _, rstd = _layer_norm_fwd(x_rms, rmsnorm_weight, None, rmsnorm_eps, z_rms, out=out,
819
+ group_size=dim // ngroups,
820
+ norm_before_gate=norm_before_gate, is_rms_norm=True)
821
+ if d_nonssm == 0:
822
+ out = rearrange(out, "(b s) d -> b s d", b=batch)
823
+ else:
824
+ out = out01
825
+ ctx.outproj_weight_dtype = outproj_weight.dtype if outproj_weight is not None else None
826
+ if outproj_weight is not None:
827
+ if torch.is_autocast_enabled():
828
+ dtype = torch.get_autocast_gpu_dtype()
829
+ out, outproj_weight = out.to(dtype), outproj_weight.to(dtype)
830
+ outproj_bias = outproj_bias.to(dtype) if outproj_bias is not None else None
831
+ out = F.linear(out, outproj_weight, outproj_bias)
832
+ else:
833
+ assert outproj_bias is None
834
+ ctx.save_for_backward(zxbcdt, conv1d_weight, conv1d_bias,
835
+ out_x, A, D, dt_bias, initial_states, seq_idx, rmsnorm_weight, rstd, outproj_weight, outproj_bias)
836
+ ctx.dt_limit = dt_limit
837
+ ctx.return_final_states = return_final_states
838
+ ctx.activation = activation
839
+ ctx.rmsnorm_eps = rmsnorm_eps
840
+ ctx.norm_before_gate = norm_before_gate
841
+ ctx.chunk_size = chunk_size
842
+ ctx.headdim = headdim
843
+ ctx.ngroups = ngroups
844
+ return out if not return_final_states else (out, final_states)
845
+
846
+ @staticmethod
847
+ @custom_bwd
848
+ def backward(ctx, dout, *args):
849
+ assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
850
+ zxbcdt, conv1d_weight, conv1d_bias, out, A, D, dt_bias, initial_states, seq_idx, rmsnorm_weight, rstd, outproj_weight, outproj_bias = ctx.saved_tensors
851
+ dfinal_states = args[0] if ctx.return_final_states else None
852
+ headdim = ctx.headdim
853
+ nheads = D.shape[0]
854
+ dim = nheads * headdim
855
+ assert nheads % ctx.ngroups == 0
856
+ dstate = (conv1d_weight.shape[0] - dim) // ctx.ngroups // 2
857
+ d_nonssm = (zxbcdt.shape[-1] - 2 * dim - 2 * ctx.ngroups * dstate - nheads) // 2
858
+ assert d_nonssm >= 0
859
+ recompute_output = outproj_weight is not None
860
+ if recompute_output:
861
+ out_recompute = torch.empty(*out.shape[:2], d_nonssm + dim, device=out.device, dtype=out.dtype)
862
+ out0_recompute, out1_recompute = out_recompute.split([d_nonssm, dim], dim=-1)
863
+ zx0, z, xBC, dt = torch.split(zxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1)
864
+ # Recompute x, B, C
865
+ xBC_conv = rearrange(
866
+ causal_conv1d_cuda.causal_conv1d_fwd(rearrange_and_update_stride(xBC, "b s d -> b d s"),
867
+ conv1d_weight, conv1d_bias, seq_idx, None, None, ctx.activation in ["silu", "swish"]),
868
+ "b d s -> b s d"
869
+ )
870
+ x, B, C = torch.split(xBC_conv, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1)
871
+ x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
872
+ B = rearrange(B, "b l (g n) -> b l g n", g=ctx.ngroups)
873
+ C = rearrange(C, "b l (g n) -> b l g n", g=ctx.ngroups)
874
+ dzxbcdt = torch.empty_like(zxbcdt)
875
+ dzx0, dz, dxBC_given, ddt_given = torch.split(dzxbcdt, [2 * d_nonssm, dim, dim + 2 * ctx.ngroups * dstate, nheads], dim=-1)
876
+ dxBC = torch.empty_like(xBC)
877
+ dx, dB, dC = torch.split(dxBC, [dim, ctx.ngroups * dstate, ctx.ngroups * dstate], dim=-1)
878
+ z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
879
+ dx = rearrange(dx, "b l (h p) -> b l h p", h=nheads)
880
+ dB = rearrange(dB, "b l (g n) -> b l g n", g=ctx.ngroups)
881
+ dC = rearrange(dC, "b l (g n) -> b l g n", g=ctx.ngroups)
882
+ if outproj_weight is not None:
883
+ dout_og = dout
884
+ dout = F.linear(dout, outproj_weight.t())
885
+ if d_nonssm > 0:
886
+ dout0, dout = dout.split([d_nonssm, dim], dim=-1)
887
+ _swiglu_bwd(zx0, dout0, dxy=dzx0, recompute_output=True, out=out0_recompute)
888
+ dout = rearrange(dout, "b s (h p) -> b s h p", p=headdim)
889
+ if rmsnorm_weight is None:
890
+ dz = rearrange(dz, "b l (h p) -> b l h p", h=nheads)
891
+ dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states, *rest = _mamba_chunk_scan_combined_bwd(
892
+ dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=ctx.dt_limit, dx=dx, ddt=ddt_given, dB=dB, dC=dC, dz=dz, recompute_output=recompute_output
893
+ )
894
+ out_for_linear = rearrange(rest[0], "b s h p -> b s (h p)") if recompute_output else None
895
+ drmsnorm_weight = None
896
+ else:
897
+ batch = dout.shape[0]
898
+ dy_rms = rearrange(dout, "b s h p -> (b s) (h p)")
899
+ dz = rearrange(dz, "b l d -> (b l) d")
900
+ x_rms = rearrange(out, "b s h p -> (b s) (h p)")
901
+ z_rms = rearrange(z, "b s h p -> (b s) (h p)")
902
+ out1_recompute = rearrange(out1_recompute, "b s d -> (b s) d") if recompute_output else None
903
+ dout, drmsnorm_weight, _, dz, *rest = _layer_norm_bwd(dy_rms, x_rms, rmsnorm_weight, None, ctx.rmsnorm_eps, None, rstd, z_rms, group_size=dim//ctx.ngroups, norm_before_gate=ctx.norm_before_gate, is_rms_norm=True, recompute_output=recompute_output, dz=dz, out=out1_recompute if recompute_output else None)
904
+ out_for_linear = out_recompute if recompute_output else None
905
+ dout = rearrange(dout, "(b s) (h p) -> b s h p", b=batch, p=headdim)
906
+ dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(
907
+ dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=ctx.dt_limit, dx=dx, ddt=ddt_given, dB=dB, dC=dC
908
+ )
909
+
910
+ if outproj_weight is not None:
911
+ doutproj_weight = torch.einsum("bso,bsd->od", dout_og, out_for_linear)
912
+ doutproj_bias = dout_og.sum(dim=(0, 1)) if outproj_bias is not None else None
913
+ else:
914
+ doutproj_weight, doutproj_bias = None, None
915
+ dxBC_given = rearrange(dxBC_given, "b s d -> b d s")
916
+ dxBC_given_update, dweight, dbias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
917
+ rearrange_and_update_stride(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias,
918
+ rearrange(dxBC, "b s d -> b d s"), seq_idx, None, None, rearrange_and_update_stride(dxBC_given), False, ctx.activation in ["silu", "swish"]
919
+ )
920
+ if dxBC_given.stride() != dxBC_given_update.stride():
921
+ dxBC_given.copy_(dxBC_given_update)
922
+ else:
923
+ dxBC_given = dxBC_given_update
924
+ dxBC_given = rearrange(dxBC_given, "b d s -> b s d")
925
+ return dzxbcdt, dweight, dbias, ddt_bias, dA, dD, None, dinitial_states, None, None, None, None, drmsnorm_weight, None, doutproj_weight, doutproj_bias, None, None, None
926
+
927
+
928
+ def mamba_split_conv1d_scan_combined(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True):
929
+ """
930
+ Argument:
931
+ zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
932
+ conv1d_weight: (dim + 2 * ngroups * dstate, width)
933
+ conv1d_bias: (dim + 2 * ngroups * dstate,)
934
+ dt_bias: (nheads,)
935
+ A: (nheads)
936
+ D: (nheads, headdim) or (nheads,)
937
+ initial_states: (batch, nheads, headdim, dstate)
938
+ seq_idx: (batch, seqlen), int32
939
+ rmsnorm_weight: (dim,)
940
+ outproj_weight: (out_dim, dim)
941
+ outproj_bias: (out_dim,)
942
+ headdim: if D is 1D, headdim must be passed in
943
+ norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
944
+ Return:
945
+ out: (batch, seqlen, dim)
946
+ """
947
+ return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
948
+
949
+
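For reference, a sketch of how the fused input zxbcdt is laid out along its last dimension, following the splits performed in the forward above (the optional d_nonssm channels are omitted and the names here are illustrative):

# dim = nheads * headdim
# zxbcdt = torch.cat([z, x, B_flat, C_flat, dt], dim=-1), where
#   z      : (batch, seqlen, dim)              gate branch
#   x      : (batch, seqlen, dim)              SSM input
#   B_flat : (batch, seqlen, ngroups * dstate)
#   C_flat : (batch, seqlen, ngroups * dstate)
#   dt     : (batch, seqlen, nheads)
# Only the [x, B_flat, C_flat] slice ("xBC") goes through the causal conv1d.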
950
+ def mamba_split_conv1d_scan_ref(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, dt_limit=(0.0, float("inf")), activation="silu", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True):
951
+ """
952
+ Argument:
953
+ zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
954
+ conv1d_weight: (dim + 2 * ngroups * dstate, width)
955
+ conv1d_bias: (dim + 2 * ngroups * dstate,)
956
+ dt_bias: (nheads,)
957
+ A: (nheads)
958
+ D: (nheads, headdim) or (nheads,)
959
+ rmsnorm_weight: (dim,)
960
+ outproj_weight: (out_dim, dim)
961
+ outproj_bias: (out_dim,)
962
+ headdim: if D is 1D, headdim must be passed in
963
+ norm_before_gate: if True, we do RMSNorm(x) * F.silu(z). If False, we do RMSNorm(x * F.silu(z))
964
+ Return:
965
+ out: (batch, seqlen, dim)
966
+ """
967
+ if D.dim() == 1:
968
+ assert headdim is not None
969
+ nheads, = D.shape
970
+ else:
971
+ nheads, headdim = D.shape
972
+ assert nheads % ngroups == 0
973
+ batch, seqlen, _ = zxbcdt.shape
974
+ dim = nheads * headdim
975
+ dstate = (zxbcdt.shape[-1] - 2 * dim - nheads) // ngroups // 2
976
+ assert zxbcdt.shape == (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads)
977
+ assert dt_bias.shape == (nheads,)
978
+ assert A.shape == (nheads,)
979
+ if rmsnorm_weight is not None:
980
+ assert rmsnorm_weight.shape == (dim,)
981
+ z, xBC, dt = torch.split(zxbcdt, [dim, dim + 2 * ngroups * dstate, nheads], dim=-1)
982
+ xBC = rearrange(causal_conv1d_fn(rearrange(xBC, "b s d -> b d s"), conv1d_weight, conv1d_bias, activation=activation),
983
+ "b d s -> b s d")
984
+ x, B, C = torch.split(xBC, [dim, ngroups * dstate, ngroups * dstate], dim=-1)
985
+ x = rearrange(x, "b l (h p) -> b l h p", h=nheads)
986
+ B = rearrange(B, "b l (g n) -> b l g n", g=ngroups)
987
+ C = rearrange(C, "b l (g n) -> b l g n", g=ngroups)
988
+ z = rearrange(z, "b l (h p) -> b l h p", h=nheads)
989
+ out = ssd_selective_scan(x, dt.to(x.dtype), A, B, C, D=D.float(),
990
+ z=z if rmsnorm_weight is None else None, dt_bias=dt_bias, dt_softplus=True, dt_limit=dt_limit)
991
+ out = rearrange(out, "b s h p -> b s (h p)")
992
+ if rmsnorm_weight is not None:
993
+ out = rmsnorm_fn(out, rmsnorm_weight, None, z=rearrange(z, "b l h p -> b l (h p)"), eps=rmsnorm_eps,
994
+ norm_before_gate=norm_before_gate)
995
+ if outproj_weight is not None:
996
+ out = F.linear(out, outproj_weight, outproj_bias)
997
+ return out
998
+
build/torch210-cxx11-cu126-x86_64-linux/ops/triton/ssd_state_passing.py ADDED
@@ -0,0 +1,348 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ """We want triton==2.1.0 or 2.2.0 for this
4
+ """
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ import triton
11
+ import triton.language as tl
12
+
13
+ from einops import rearrange, repeat
14
+
15
+
16
+ @triton.autotune(
17
+ configs=[
18
+ triton.Config({'BLOCK_SIZE': 64}),
19
+ triton.Config({'BLOCK_SIZE': 128}),
20
+ triton.Config({'BLOCK_SIZE': 256}),
21
+ triton.Config({'BLOCK_SIZE': 512}),
22
+ triton.Config({'BLOCK_SIZE': 1024}),
23
+ triton.Config({'BLOCK_SIZE': 2048}),
24
+ ],
25
+ key=['dim'],
26
+ )
27
+ @triton.jit
28
+ def _state_passing_fwd_kernel(
29
+ # Pointers to matrices
30
+ states_ptr, out_ptr, final_states_ptr, dA_cs_ptr, initstates_ptr, seq_idx_ptr,
31
+ # Matrix dimensions
32
+ dim, nchunks, seqlen, chunk_size,
33
+ # Strides
34
+ stride_states_batch, stride_states_chunk, stride_states_head, stride_states_dim,
35
+ stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,
36
+ stride_final_states_batch, stride_final_states_head, stride_final_states_dim,
37
+ stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,
38
+ stride_initstates_batch, stride_initstates_head, stride_initstates_dim,
39
+ stride_seq_idx_batch, stride_seq_idx_seqlen,
40
+ # Meta-parameters
41
+ HAS_INITSTATES: tl.constexpr,
42
+ HAS_SEQ_IDX: tl.constexpr,
43
+ BLOCK_SIZE: tl.constexpr,
44
+ ):
45
+ pid_b = tl.program_id(axis=1)
46
+ pid_h = tl.program_id(axis=2)
47
+ pid_m = tl.program_id(axis=0)
48
+ states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
49
+ dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head
50
+ out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
51
+ final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head
52
+ if HAS_INITSTATES:
53
+ initstates_ptr += pid_b * stride_initstates_batch + pid_h * stride_initstates_head
54
+ if HAS_SEQ_IDX:
55
+ seq_idx_ptr += pid_b * stride_seq_idx_batch
56
+
57
+ offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
58
+ states_ptrs = states_ptr + offs_m * stride_states_dim
59
+ out_ptrs = out_ptr + offs_m * stride_out_dim
60
+ final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim
61
+
62
+ if not HAS_INITSTATES:
63
+ states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
64
+ else:
65
+ initstates_ptrs = initstates_ptr + offs_m * stride_initstates_dim
66
+ states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
67
+ tl.store(out_ptrs, states, mask=offs_m < dim)
68
+ out_ptrs += stride_out_chunk
69
+ seq_idx = 0
70
+ for c in range(nchunks):
71
+ new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
72
+ dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
73
+ scale = tl.exp(dA_cs)
74
+ if HAS_SEQ_IDX:
75
+ seq_idx_new = tl.load(seq_idx_ptr + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen)
76
+ scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)
77
+ seq_idx = seq_idx_new
78
+ states = scale * states + new_states
79
+ if c < nchunks - 1:
80
+ tl.store(out_ptrs, states, mask=offs_m < dim)
81
+ else:
82
+ tl.store(final_states_ptrs, states, mask=offs_m < dim)
83
+ states_ptrs += stride_states_chunk
84
+ dA_cs_ptr += stride_dA_cs_chunk
85
+ out_ptrs += stride_out_chunk
86
+
87
+
88
+ @triton.autotune(
89
+ configs=[
90
+ triton.Config({'BLOCK_SIZE': 64}),
91
+ triton.Config({'BLOCK_SIZE': 128}),
92
+ triton.Config({'BLOCK_SIZE': 256}),
93
+ triton.Config({'BLOCK_SIZE': 512}),
94
+ triton.Config({'BLOCK_SIZE': 1024}),
95
+ triton.Config({'BLOCK_SIZE': 2048}),
96
+ ],
97
+ key=['dim'],
98
+ )
99
+ @triton.jit
100
+ def _state_passing_bwd_kernel(
101
+ # Pointers to matrices
102
+ dout_ptr, out_ptr, dA_cs_ptr, dfinal_states_ptr, seq_idx_ptr,
103
+ dstates_ptr, ddA_cs_ptr, dinitstates_ptr, states_converted_ptr,
104
+ # Matrix dimensions
105
+ dim, nchunks, seqlen, chunk_size,
106
+ # Strides
107
+ stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_dim,
108
+ stride_out_batch, stride_out_chunk, stride_out_head, stride_out_dim,
109
+ stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head,
110
+ stride_dfinal_states_batch, stride_dfinal_states_head, stride_dfinal_states_dim,
111
+ stride_seq_idx_batch, stride_seq_idx_seqlen,
112
+ stride_dstates_batch, stride_dstates_chunk, stride_dstates_head, stride_dstates_dim,
113
+ stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head,
114
+ stride_dinitstates_batch, stride_dinitstates_head, stride_dinitstates_dim,
115
+ # Meta-parameters
116
+ CONVERT_STATES: tl.constexpr,
117
+ HAS_DFINAL_STATES: tl.constexpr,
118
+ HAS_DINITSTATES: tl.constexpr,
119
+ HAS_SEQ_IDX: tl.constexpr,
120
+ BLOCK_SIZE: tl.constexpr,
121
+ ):
122
+ pid_b = tl.program_id(axis=1)
123
+ pid_h = tl.program_id(axis=2)
124
+ pid_m = tl.program_id(axis=0)
125
+ dstates_ptr += pid_b * stride_dstates_batch + pid_h * stride_dstates_head + (nchunks - 1) * stride_dstates_chunk
126
+ dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + (nchunks - 1) * stride_dA_cs_chunk
127
+ ddA_cs_ptr += pid_b * stride_ddA_cs_batch + pid_h * stride_ddA_cs_head + (nchunks - 1) * stride_ddA_cs_chunk + pid_m
128
+ out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk
129
+ dout_ptr += pid_b * stride_dout_batch + pid_h * stride_dout_head + (nchunks - 1) * stride_dout_chunk
130
+ if CONVERT_STATES:
131
+ states_converted_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + (nchunks - 1) * stride_out_chunk
132
+ if HAS_DFINAL_STATES:
133
+ dfinal_states_ptr += pid_b * stride_dfinal_states_batch + pid_h * stride_dfinal_states_head
134
+ if HAS_DINITSTATES:
135
+ dinitstates_ptr += pid_b * stride_dinitstates_batch + pid_h * stride_dinitstates_head
136
+ if HAS_SEQ_IDX:
137
+ seq_idx_ptr += pid_b * stride_seq_idx_batch
138
+
139
+ offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
140
+ dstates_ptrs = dstates_ptr + offs_m * stride_dstates_dim
141
+ out_ptrs = out_ptr + offs_m * stride_out_dim
142
+ dout_ptrs = dout_ptr + offs_m * stride_dout_dim
143
+ if CONVERT_STATES:
144
+ states_converted_ptrs = states_converted_ptr + offs_m * stride_out_dim
145
+
146
+ if HAS_DFINAL_STATES:
147
+ dstates = tl.load(dfinal_states_ptr + offs_m * stride_dfinal_states_dim, mask=offs_m < dim, other=0.0).to(tl.float32)
148
+ else:
149
+ dstates = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
150
+ tl.store(dstates_ptrs, dstates, mask=offs_m < dim)
151
+ if HAS_SEQ_IDX:
152
+ seq_idx = tl.load(seq_idx_ptr + (seqlen - 1) * stride_seq_idx_seqlen)
153
+ dstates_ptrs -= stride_dstates_chunk
154
+ for c in range(nchunks - 1):
155
+ dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
156
+ scale = tl.exp(dA_cs)
157
+ if HAS_SEQ_IDX:
158
+ seq_idx_new = tl.load(seq_idx_ptr + (((nchunks - c - 1) * chunk_size - 1) * stride_seq_idx_seqlen))
159
+ scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)
160
+ seq_idx = seq_idx_new
161
+ out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
162
+ if CONVERT_STATES:
163
+ tl.store(states_converted_ptrs, out, mask=offs_m < dim)
164
+ ddA = tl.sum(out * dstates) * scale
165
+ tl.store(ddA_cs_ptr, ddA)
166
+ dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
167
+ dstates = scale * dstates + dout
168
+ tl.store(dstates_ptrs, dstates, mask=offs_m < dim)
169
+ dout_ptrs -= stride_dout_chunk
170
+ dstates_ptrs -= stride_dstates_chunk
171
+ dA_cs_ptr -= stride_dA_cs_chunk
172
+ ddA_cs_ptr -= stride_ddA_cs_chunk
173
+ out_ptrs -= stride_out_chunk
174
+ if CONVERT_STATES:
175
+ states_converted_ptrs -= stride_out_chunk
176
+ if CONVERT_STATES:
177
+ out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
178
+ tl.store(states_converted_ptrs, out, mask=offs_m < dim)
179
+ if not HAS_DINITSTATES:
180
+ tl.store(ddA_cs_ptr, 0.0)
181
+ else:
182
+ dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
183
+ scale = tl.exp(dA_cs)
184
+ if HAS_SEQ_IDX:
185
+ scale = tl.where(seq_idx == 0, scale, 0.0)
186
+ out = tl.load(out_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
187
+ ddA = tl.sum(out * dstates) * scale
188
+ tl.store(ddA_cs_ptr, ddA)
189
+ dout = tl.load(dout_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
190
+ dstates = scale * dstates + dout
191
+ tl.store(dinitstates_ptr + offs_m * stride_dinitstates_dim, dstates, mask=offs_m < dim)
192
+
193
+
194
+ def _state_passing_fwd(states, dA_chunk_cumsum, initial_states=None, seq_idx=None, chunk_size=None,
195
+ out_dtype=None):
196
+ batch, nchunks, nheads, dim = states.shape
197
+ assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
198
+ if initial_states is not None:
199
+ assert initial_states.shape == (batch, nheads, dim)
200
+ if seq_idx is not None:
201
+ assert chunk_size is not None
202
+ seqlen = seq_idx.shape[-1]
203
+ assert seq_idx.shape == (batch, seqlen)
204
+ out_dtype = states.dtype if out_dtype is None else out_dtype
205
+ out = torch.empty((batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype)
206
+ final_states = torch.empty((batch, nheads, dim), device=states.device, dtype=torch.float32)
207
+ grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)
208
+ with torch.cuda.device(states.device.index):
209
+ _state_passing_fwd_kernel[grid](
210
+ states, out, final_states, dA_chunk_cumsum, initial_states, seq_idx,
211
+ dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0,
212
+ states.stride(0), states.stride(1), states.stride(2), states.stride(3),
213
+ out.stride(0), out.stride(1), out.stride(2), out.stride(3),
214
+ final_states.stride(0), final_states.stride(1), final_states.stride(2),
215
+ dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),
216
+ *((initial_states.stride(0), initial_states.stride(1), initial_states.stride(2))
217
+ if initial_states is not None else (0, 0, 0)),
218
+ *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
219
+ HAS_INITSTATES=initial_states is not None,
220
+ HAS_SEQ_IDX=seq_idx is not None,
221
+ )
222
+ return out, final_states
223
+
224
+
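For intuition, the recurrence computed by the kernel above can be written in plain PyTorch roughly as below. This is a sketch that ignores seq_idx handling and output dtypes; the helper name is hypothetical and it is not used by the fast path.

def state_passing_fwd_ref_loop(states, dA_chunk_cumsum, initial_states=None):
    # states: (batch, nchunks, nheads, dim), dA_chunk_cumsum: (batch, nheads, nchunks)
    batch, nchunks, nheads, dim = states.shape
    out = torch.empty_like(states, dtype=torch.float32)
    carry = (initial_states.float() if initial_states is not None
             else torch.zeros(batch, nheads, dim, device=states.device))
    for c in range(nchunks):
        out[:, c] = carry                                        # state entering chunk c
        decay = torch.exp(dA_chunk_cumsum[:, :, c].float()).unsqueeze(-1)
        carry = decay * carry + states[:, c].float()
    return out, carry                                            # carry == final_states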
225
+ def _state_passing_bwd(
226
+ states, dA_chunk_cumsum, dout, dfinal_states=None, seq_idx=None, has_initial_states=None,
227
+ dstates_dtype=None, states_dtype=None, chunk_size=None
228
+ ):
229
+ """
230
+ states contains the initial_states at index 0. The final states are not included in states.
231
+ """
232
+ batch, nchunks, nheads, dim = states.shape
233
+ assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
234
+ assert dout.shape == (batch, nchunks, nheads, dim)
235
+ if seq_idx is not None:
236
+ assert chunk_size is not None
237
+ seqlen = seq_idx.shape[-1]
238
+ assert seq_idx.shape == (batch, seqlen)
239
+ dstates = torch.empty_like(dout, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)
240
+ if states_dtype is not None and states_dtype != states.dtype:
241
+ states_converted = torch.empty_like(states, dtype=dstates_dtype if dstates_dtype is not None else dout.dtype)
242
+ assert states_converted.stride() == states.stride()
243
+ else:
244
+ states_converted = None
245
+ if has_initial_states:
246
+ dinitstates = torch.empty_like(dstates[:, 0])
247
+ else:
248
+ dinitstates = None
249
+ if dfinal_states is not None:
250
+ assert dfinal_states.shape == (batch, nheads, dim)
251
+ BLOCK_SIZE_min = 64
252
+ n_blocks = (dim + BLOCK_SIZE_min - 1) // BLOCK_SIZE_min
253
+ ddA_chunk_cumsum = torch.empty(batch, nheads, nchunks, n_blocks,
254
+ dtype=torch.float32, device=dA_chunk_cumsum.device)
255
+ grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)
256
+ with torch.cuda.device(dout.device.index):
257
+ _state_passing_bwd_kernel[grid](
258
+ dout, states, dA_chunk_cumsum, dfinal_states, seq_idx,
259
+ dstates, ddA_chunk_cumsum, dinitstates, states_converted,
260
+ dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size if seq_idx is not None else 0,
261
+ dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3),
262
+ states.stride(0), states.stride(1), states.stride(2), states.stride(3),
263
+ dA_chunk_cumsum.stride(0), dA_chunk_cumsum.stride(2), dA_chunk_cumsum.stride(1),
264
+ *((dfinal_states.stride(0), dfinal_states.stride(1), dfinal_states.stride(2))
265
+ if dfinal_states is not None else (0, 0, 0)),
266
+ *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
267
+ dstates.stride(0), dstates.stride(1), dstates.stride(2), dstates.stride(3),
268
+ ddA_chunk_cumsum.stride(0), ddA_chunk_cumsum.stride(2), ddA_chunk_cumsum.stride(1),
269
+ *((dinitstates.stride(0), dinitstates.stride(1), dinitstates.stride(2))
270
+ if dinitstates is not None else (0, 0, 0)),
271
+ CONVERT_STATES=states_converted is not None,
272
+ HAS_DFINAL_STATES=dfinal_states is not None,
273
+ HAS_DINITSTATES=dinitstates is not None,
274
+ HAS_SEQ_IDX=seq_idx is not None,
275
+ )
276
+ BLOCK_SIZE_actual = _state_passing_bwd_kernel.best_config.kwargs["BLOCK_SIZE"]
277
+ n_valid_blocks = (dim + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual
278
+ ddA_chunk_cumsum = ddA_chunk_cumsum[..., :n_valid_blocks].sum(dim=-1).to(dtype=dA_chunk_cumsum.dtype)
279
+ if states_dtype is not None and states_dtype == states.dtype:
280
+ states_converted = states
281
+ return (dstates, ddA_chunk_cumsum, dinitstates) if states_dtype is None else (dstates, ddA_chunk_cumsum, dinitstates, states_converted)
282
+
283
+
284
+ class StatePassingFn(torch.autograd.Function):
285
+
286
+ @staticmethod
287
+ def forward(ctx, states, dA_chunk_cumsum, initial_states=None):
288
+ batch, nchunks, nheads, dim = states.shape
289
+ assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
290
+ if states.stride(-1) != 1:
291
+ states = states.contiguous()
292
+ out, final_states = _state_passing_fwd(states, dA_chunk_cumsum, initial_states)
293
+ ctx.save_for_backward(out, dA_chunk_cumsum)
294
+ ctx.has_initial_states = initial_states is not None
295
+ return out, final_states
296
+
297
+ @staticmethod
298
+ def backward(ctx, dout, dfinal_states):
299
+ out, dA_chunk_cumsum = ctx.saved_tensors
300
+ batch, nchunks, nheads, dim = out.shape
301
+ assert dout.shape == (batch, nchunks, nheads, dim)
302
+ assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
303
+ assert dfinal_states.shape == (batch, nheads, dim)
304
+ if dout.stride(-1) != 1:
305
+ dout = dout.contiguous()
306
+ dstates, ddA_chunk_cumsum, dinitstates = _state_passing_bwd(
307
+ out, dA_chunk_cumsum, dout, dfinal_states=dfinal_states, has_initial_states=ctx.has_initial_states
308
+ )
309
+ return dstates, ddA_chunk_cumsum, dinitstates
310
+
311
+
312
+ def state_passing(states, dA_chunk_cumsum, initial_states=None):
313
+ """
314
+ Argument:
315
+ states: (batch, nchunks, nheads, dim)
316
+ dA_chunk_cumsum: (batch, nheads, nchunks)
317
+ initial_states: (batch, nheads, dim)
318
+ Return:
319
+ out: (batch, nchunks, nheads, dim)
320
+ final_states: (batch, nheads, dim)
321
+ """
322
+ return StatePassingFn.apply(states, dA_chunk_cumsum, initial_states)
323
+
324
+
325
+ def state_passing_ref(states, dA_chunk_cumsum, initial_states=None):
326
+ """
327
+ Argument:
328
+ states: (batch, nchunks, nheads, dim)
329
+ dA_chunk_cumsum: (batch, nheads, nchunks)
330
+ initial_states: (batch, nheads, dim)
331
+ Return:
332
+ out: (batch, nchunks, nheads, dim)
333
+ final_states: (batch, nheads, dim)
334
+ """
335
+ if initial_states is None:
336
+ initial_states = torch.zeros_like(states[:, 0])
337
+ states = torch.cat([rearrange(initial_states, "b h d -> b 1 h d"), states], dim=1)
338
+ dA_chunk_cumsum = F.pad(dA_chunk_cumsum, (1, 0))
339
+ dA_chunk_cumsum = torch.cumsum(dA_chunk_cumsum, dim=-1)
340
+ nchunks = dA_chunk_cumsum.shape[-1]
341
+ # (batch, nheads, nchunks, nchunks)
342
+ dt_chunk_segment_sum = dA_chunk_cumsum[:, :, :, None] - dA_chunk_cumsum[:, :, None, :]
343
+ # (batch, nheads, nchunks, nchunks)
344
+ decay_chunk = torch.exp(dt_chunk_segment_sum)
345
+ causal_mask = torch.tril(torch.ones(nchunks, nchunks, device=states.device, dtype=bool), diagonal=0)
346
+ decay_chunk = decay_chunk.masked_fill(~causal_mask, 0)
347
+ out = torch.einsum("bhzc,bchd->bzhd", decay_chunk.to(dtype=states.dtype), states)
348
+ return out[:, :-1], out[:, -1]
build/torch210-cxx11-cu126-x86_64-linux/utils/__init__.py ADDED
File without changes
build/torch210-cxx11-cu126-x86_64-linux/utils/generation.py ADDED
@@ -0,0 +1,390 @@
1
+ # Copyright (c) 2023, Albert Gu, Tri Dao.
2
+ import gc
3
+ import time
4
+ from collections import namedtuple
5
+ from dataclasses import dataclass, field
6
+ from functools import partial
7
+ from typing import Callable, Optional, Sequence, Union
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from einops import rearrange, repeat
12
+ from torch import Tensor
13
+ from torch.profiler import ProfilerActivity, profile, record_function
14
+ from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer
15
+
16
+
17
+ @dataclass
18
+ class InferenceParams:
19
+ """Inference parameters that are passed to the main model in order
20
+ to efficiently calculate and store the context during inference."""
21
+
22
+ max_seqlen: int
23
+ max_batch_size: int
24
+ seqlen_offset: int = 0
25
+ batch_size_offset: int = 0
26
+ key_value_memory_dict: dict = field(default_factory=dict)
27
+ lengths_per_sample: Optional[Tensor] = None
28
+
29
+ def reset(self, max_seqlen, max_batch_size):
30
+ self.max_seqlen = max_seqlen
31
+ self.max_batch_size = max_batch_size
32
+ self.seqlen_offset = 0
33
+ if self.lengths_per_sample is not None:
34
+ self.lengths_per_sample.zero_()
35
+
36
+
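A small usage sketch for the cache object above (sizes are hypothetical). During step-by-step decoding the model reads seqlen_offset and the per-layer buffers stored in key_value_memory_dict:

params = InferenceParams(max_seqlen=2048, max_batch_size=4)
# run the prefill forward with inference_params=params, then advance the offset,
# e.g. after a 128-token prompt:
params.seqlen_offset += 128
# reuse the same object for the next prompt:
params.reset(max_seqlen=2048, max_batch_size=4)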
37
+ def modify_logits_for_min_p_filtering(logits, min_p):
38
+ """Set the logits for none min_p values to -inf. Done in-place."""
39
+ if min_p <= 0.0 or min_p >= 1.0:
40
+ return
41
+ indices_to_remove = logits < min_p
42
+ logits.masked_fill_(indices_to_remove, float("-Inf"))
43
+ # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
44
+ # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
45
+ def modify_logits_for_top_k_filtering(logits, top_k):
46
+ """Set the logits for none top-k values to -inf. Done in-place."""
47
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
48
+ logits.masked_fill_(indices_to_remove, float("-Inf"))
49
+
50
+
51
+ # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
52
+ # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
53
+ def modify_logits_for_top_p_filtering(logits, top_p):
54
+ """Set the logits for none top-p values to -inf. Done in-place."""
55
+ if top_p <= 0.0 or top_p >= 1.0:
56
+ return
57
+ # First sort and calculate cumulative sum of probabilities.
58
+ sorted_logits, sorted_indices = torch.sort(logits, descending=False)
59
+ cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
60
+ # Remove tokens whose ascending cumulative probability stays at or below 1 - top_p; the remaining top-p nucleus is kept.
61
+ sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
62
+ # scatter sorted tensors to original indexing
63
+ indices_to_remove = sorted_indices_to_remove.scatter(
64
+ 1, sorted_indices, sorted_indices_to_remove
65
+ )
66
+ logits.masked_fill_(indices_to_remove, float("-inf"))
67
+
68
+
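A worked example of the rule above, with hypothetical numbers: for token probabilities [0.5, 0.3, 0.15, 0.05] and top_p = 0.9, the ascending cumulative sums are [0.05, 0.2, 0.5, 1.0]; only the 0.05 token satisfies cumsum <= 1 - top_p = 0.1, so it is masked to -inf and the {0.5, 0.3, 0.15} nucleus (total mass 0.95) is kept.

logits = torch.log(torch.tensor([[0.5, 0.3, 0.15, 0.05]]))
modify_logits_for_top_p_filtering(logits, top_p=0.9)
# logits[0, 3] is now -inf; the other three entries are unchanged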
69
+ def modify_logit_for_repetition_penalty(logits, prev_output_tokens, repetition_penalty=1.0):
70
+ """Apply repetition penalty. See https://arxiv.org/abs/1909.05858
71
+ logits: (batch_size, vocab_size)
72
+ prev_output_tokens: (batch_size, seq_len)
73
+ """
74
+ if repetition_penalty == 1.0:
75
+ return logits
76
+ score = torch.gather(logits, 1, prev_output_tokens)
77
+ # If score < 0, multiply by the penalty (rather than divide) so the previous token's probability is still reduced.
78
+ score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
79
+ logits.scatter_(1, prev_output_tokens, score)
80
+ return logits
81
+
82
+
83
+ def sample(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):
84
+ """Sample from top-k logits.
85
+ Arguments:
86
+ logits: Tensor of shape (batch_size, vocab_size)
87
+ """
88
+ if top_k == 1: # Short-circuit for greedy decoding
89
+ return logits.argmax(dim=-1)
90
+ else:
91
+ if top_p > 0.0:
92
+ assert top_p <= 1.0, "top-p should be in (0, 1]."
93
+ if top_k > 0:
94
+ top_k = min(top_k, logits.size(-1)) # Safety check
95
+ logits_top, indices = torch.topk(logits, top_k, dim=-1)
96
+ if temperature != 1.0:
97
+ logits_top /= temperature
98
+ modify_logits_for_top_p_filtering(logits_top, top_p)
99
+ return indices[
100
+ torch.arange(indices.shape[0], device=indices.device),
101
+ torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
102
+ ]
103
+ else:
104
+ if min_p > 0.0:
105
+ logits_top = logits.clone()
106
+ max_prob = logits_top[..., 0].item()
107
+ min_prob = max_prob * min_p
108
+ modify_logits_for_min_p_filtering(logits_top, min_prob)
109
+ if temperature != 1.0:
110
+ logits_top /= temperature
111
+ return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
112
+ # Clone so that when we modify for top_p we don't change the original logits
113
+ logits_top = logits / temperature if temperature != 1.0 else logits.clone()
114
+ modify_logits_for_top_p_filtering(logits_top, top_p)
115
+ return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
116
+ dim=-1
117
+ )
118
+
119
+
120
+ @torch.inference_mode()
121
+ def decode(
122
+ input_ids,
123
+ model,
124
+ max_length,
125
+ top_k=1,
126
+ top_p=0.0,
127
+ min_p=0.0,
128
+ temperature=1.0,
129
+ repetition_penalty=1.0,
130
+ eos_token_id=None,
131
+ teacher_outputs=None,
132
+ vocab_size=None,
133
+ cg=False,
134
+ enable_timing=False,
135
+ output_scores=False,
136
+ streamer: Optional[TextStreamer] = None
137
+ ):
138
+ """Decoding, either greedy or with top-k or top-p sampling.
139
+ If top-k = 0, don't limit the number of candidates (pure sampling).
140
+ Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
141
+ then top-p.
142
+ We assume that all sequences in the same batch have the same length.
143
+
144
+ Arguments:
145
+ input_ids: (batch, seq_len)
146
+ max_length: int
147
+ teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
148
+ logits, the next token is taken from the teacher_outputs. Useful for testing.
149
+ Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
150
+ sequences: (batch, max_length)
151
+ scores: tuples of (batch, vocab_size)
152
+ """
153
+ if streamer is not None:
154
+ streamer.put(input_ids.cpu())
155
+
156
+ batch_size, seqlen_og = input_ids.shape
157
+ teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
158
+ if cg:
159
+ if not hasattr(model, "_decoding_cache"):
160
+ model._decoding_cache = None
161
+ model._decoding_cache = update_graph_cache(
162
+ model,
163
+ model._decoding_cache,
164
+ batch_size,
165
+ seqlen_og,
166
+ max_length,
167
+ )
168
+ inference_params = model._decoding_cache.inference_params
169
+ inference_params.reset(max_length, batch_size)
170
+ else:
171
+ inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
172
+
173
+ def get_logits(input_ids, inference_params):
174
+ decoding = inference_params.seqlen_offset > 0
175
+ if decoding:
176
+ position_ids = torch.full(
177
+ (batch_size, 1),
178
+ inference_params.seqlen_offset,
179
+ dtype=torch.long,
180
+ device=input_ids.device,
181
+ )
182
+ else:
183
+ position_ids = None
184
+ if not cg or not decoding:
185
+ logits = model(
186
+ input_ids,
187
+ position_ids=position_ids,
188
+ inference_params=inference_params,
189
+ num_last_tokens=1,
190
+ ).logits.squeeze(dim=1)
191
+ else:
192
+ logits = model._decoding_cache.run(
193
+ input_ids, position_ids, inference_params.seqlen_offset
194
+ ).squeeze(dim=1)
195
+ return logits[..., :vocab_size] if vocab_size is not None else logits
196
+
197
+ def sample_tokens(logits, inference_params):
198
+ if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
199
+ token = sample(logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature)
200
+ else:
201
+ token = teacher_outputs[:, inference_params.seqlen_offset]
202
+ # return rearrange(token, "b -> b 1")
203
+ return token.unsqueeze(1)
204
+
205
+ def should_stop(current_token, inference_params):
206
+ if inference_params.seqlen_offset == 0:
207
+ return False
208
+ if eos_token_id is not None and (current_token == eos_token_id).all():
209
+ return True
210
+ if inference_params.seqlen_offset >= max_length - 1:
211
+ return True
212
+ return False
213
+
214
+ start = torch.cuda.Event(enable_timing=enable_timing)
215
+ end = torch.cuda.Event(enable_timing=enable_timing)
216
+
217
+ if enable_timing:
218
+ start.record()
219
+ scores, sequences = [], [input_ids]
220
+ sequences_cat = input_ids
221
+ while not should_stop(sequences[-1], inference_params):
222
+ logits = get_logits(sequences[-1], inference_params)
223
+ if output_scores:
224
+ scores.append(logits.clone())
225
+ inference_params.seqlen_offset += sequences[-1].shape[1]
226
+ if repetition_penalty == 1.0:
227
+ sampled_tokens = sample_tokens(logits, inference_params)
228
+ else:
229
+ logits = modify_logit_for_repetition_penalty(
230
+ logits, sequences_cat, repetition_penalty
231
+ )
232
+ sampled_tokens = sample_tokens(logits, inference_params)
233
+ sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)
234
+ sequences.append(sampled_tokens)
235
+ if streamer is not None:
236
+ streamer.put(sampled_tokens.cpu())
237
+ if streamer is not None:
238
+ streamer.end()
239
+ if enable_timing:
240
+ end.record()
241
+ torch.cuda.synchronize()
242
+ print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
243
+ output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
244
+ return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
245
+
246
+
247
+ class GenerationMixin:
248
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
249
+ raise NotImplementedError
250
+
251
+ def generate(
252
+ self,
253
+ input_ids,
254
+ max_length,
255
+ top_k=1,
256
+ top_p=0.0,
257
+ min_p=0.0,
258
+ temperature=1.0,
259
+ return_dict_in_generate=False,
260
+ output_scores=False,
261
+ **kwargs,
262
+ ):
263
+ output = decode(
264
+ input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature, output_scores=output_scores, **kwargs
265
+ )
266
+ if not output_scores:
267
+ output.scores = None
268
+ return output if return_dict_in_generate else output.sequences
269
+
270
+
271
+ @dataclass
272
+ class DecodingCGCache:
273
+ max_batch_size: int = 0
274
+ max_seqlen: int = 0
275
+ device = None
276
+ dtype = None
277
+ callables: dict = field(default_factory=dict)
278
+ mempool = None
279
+ inference_params: Optional[InferenceParams] = None
280
+ run: Optional[Callable] = None
281
+
282
+
283
+ @torch.inference_mode()
284
+ def update_graph_cache(
285
+ model,
286
+ cache,
287
+ batch_size,
288
+ seqlen_og,
289
+ max_seqlen,
290
+ decoding_seqlens=(1,),
291
+ dtype=None,
292
+ n_warmups=2,
293
+ ):
294
+ if cache is None:
295
+ cache = DecodingCGCache()
296
+ param_example = next(iter(model.parameters()))
297
+ device = param_example.device
298
+ if dtype is None:
299
+ dtype = param_example.dtype
300
+ if (
301
+ (device, dtype) != (cache.device, cache.dtype)
302
+ or batch_size > cache.max_batch_size
303
+ or max_seqlen > cache.max_seqlen
304
+ ): # Invalidate the cache
305
+ cache.callables = {}
306
+ cache.mempool = None
307
+ cache.inference_params = None
308
+ gc.collect()
309
+ cache.device, cache.dtype = device, dtype
310
+ cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
311
+ assert hasattr(model, "allocate_inference_cache"), "CUDA graph decoding requires that the model has a method allocate_inference_cache"
312
+ inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
313
+ lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
314
+ cache.inference_params = InferenceParams(
315
+ max_seqlen=max_seqlen,
316
+ max_batch_size=batch_size,
317
+ seqlen_offset=seqlen_og,
318
+ key_value_memory_dict=inf_cache,
319
+ lengths_per_sample=lengths_per_sample,
320
+ )
321
+ cache.mempool = torch.cuda.graphs.graph_pool_handle()
322
+ for decoding_seqlen in decoding_seqlens:
323
+ if (batch_size, decoding_seqlen) not in cache.callables:
324
+ cache.callables[batch_size, decoding_seqlen] = capture_graph(
325
+ model,
326
+ cache.inference_params,
327
+ batch_size,
328
+ max_seqlen,
329
+ decoding_seqlen=decoding_seqlen,
330
+ mempool=cache.mempool,
331
+ n_warmups=n_warmups,
332
+ )
333
+
334
+ def dispatch(input_ids, position_ids, seqlen):
335
+ batch_size, decoding_seqlen = input_ids.shape[:2]
336
+ return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)
337
+
338
+ cache.run = dispatch
339
+ cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing
340
+ return cache
341
+
342
+
343
+ def capture_graph(
344
+ model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
345
+ ):
346
+ device = next(iter(model.parameters())).device
347
+ input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
348
+ position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
349
+ seqlen_offset_og = inference_params.seqlen_offset
350
+ inference_params.seqlen_offset = max_seqlen - decoding_seqlen
351
+ inference_params.lengths_per_sample[:] = inference_params.seqlen_offset
352
+
353
+ # Warmup before capture
354
+ s = torch.cuda.Stream()
355
+ s.wait_stream(torch.cuda.current_stream())
356
+ with torch.cuda.stream(s):
357
+ for _ in range(n_warmups):
358
+ logits = model(
359
+ input_ids,
360
+ position_ids=position_ids,
361
+ inference_params=inference_params,
362
+ num_last_tokens=decoding_seqlen,
363
+ ).logits
364
+ s.synchronize()
365
+ # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
366
+ # which requires that graph launches and non-captured launches do not overlap
367
+ # (at least, that is how I read the documentation). I'm not sure this is required.
368
+ if torch.distributed.is_initialized():
369
+ torch.distributed.barrier()
370
+ torch.cuda.current_stream().wait_stream(s)
371
+ # Captures the graph
372
+ # To allow capture, automatically sets a side stream as the current stream in the context
373
+ graph = torch.cuda.CUDAGraph()
374
+ with torch.cuda.graph(graph, pool=mempool):
375
+ logits = model(
376
+ input_ids,
377
+ position_ids=position_ids,
378
+ inference_params=inference_params,
379
+ num_last_tokens=decoding_seqlen,
380
+ ).logits
381
+
382
+ def run(new_input_ids, new_position_ids, seqlen):
383
+ inference_params.lengths_per_sample[:] = seqlen
384
+ input_ids.copy_(new_input_ids)
385
+ position_ids.copy_(new_position_ids)
386
+ graph.replay()
387
+ return logits.clone()
388
+
389
+ inference_params.seqlen_offset = seqlen_offset_og
390
+ return run
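For reference, a minimal usage sketch of the sampling helpers above. The import path `mamba_ssm.utils.generation` and the toy shapes are assumptions for illustration; as the `decode` docstring states, top-k filtering is applied before top-p.

```python
# Minimal sketch of the sampling helpers above (import path assumed from this build).
import torch
from mamba_ssm.utils.generation import sample, modify_logit_for_repetition_penalty

logits = torch.randn(2, 32)                                    # (batch_size, vocab_size)

greedy  = sample(logits, top_k=1)                              # argmax short-circuit, shape (2,)
topk    = sample(logits, top_k=5, top_p=0.9, temperature=0.8)  # top-k first, then top-p
nucleus = sample(logits, top_k=0, top_p=0.9)                   # pure nucleus sampling

# The repetition penalty is applied to the raw logits before sampling.
prev = torch.randint(0, 32, (2, 10))                           # previously generated token ids
penalized = modify_logit_for_repetition_penalty(logits.clone(), prev, repetition_penalty=1.2)
sampled = sample(penalized, top_k=50)
```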
build/torch210-cxx11-cu126-x86_64-linux/utils/hf.py ADDED
@@ -0,0 +1,23 @@
1
+ import json
2
+
3
+ import torch
4
+
5
+ from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
6
+ from transformers.utils.hub import cached_file
7
+
8
+
9
+ def load_config_hf(model_name):
10
+ resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
11
+ return json.load(open(resolved_archive_file))
12
+
13
+
14
+ def load_state_dict_hf(model_name, device=None, dtype=None):
15
+ # If not fp32, then we don't want to load directly to the GPU
16
+ mapped_device = "cpu" if dtype not in [torch.float32, None] else device
17
+ resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
18
+ return torch.load(resolved_archive_file, map_location=mapped_device)
19
+ # Convert dtype before moving to GPU to save memory
20
+ if dtype is not None:
21
+ state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
22
+ state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
23
+ return state_dict
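A short sketch of the two loaders above. The Hub repository name is only an example and the import path mirrors this build's layout; both are assumptions.

```python
import torch
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf   # path assumed from this build

config = load_config_hf("state-spaces/mamba-130m")                   # plain dict from config.json
state_dict = load_state_dict_hf("state-spaces/mamba-130m", device="cuda", dtype=torch.float16)
print(config["d_model"], next(iter(state_dict.values())).dtype)
```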
build/torch210-cxx11-cu126-x86_64-linux/utils/torch.py ADDED
@@ -0,0 +1,21 @@
1
+ import torch
2
+ from functools import partial
3
+ from typing import Callable
4
+
5
+ def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool):
6
+ def decorator(*args, **kwargs):
7
+ if cuda_amp_deprecated:
8
+ kwargs["device_type"] = "cuda"
9
+ return dec(*args, **kwargs)
10
+ return decorator
11
+
12
+
13
+ if hasattr(torch.amp, "custom_fwd"): # type: ignore[attr-defined]
14
+ deprecated = True
15
+ from torch.amp import custom_fwd, custom_bwd # type: ignore[attr-defined]
16
+ else:
17
+ deprecated = False
18
+ from torch.cuda.amp import custom_fwd, custom_bwd
19
+
20
+ custom_fwd = custom_amp_decorator(custom_fwd, deprecated)
21
+ custom_bwd = custom_amp_decorator(custom_bwd, deprecated)
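The wrappers above only paper over the `torch.cuda.amp` to `torch.amp` move, so the same decorators work on old and new PyTorch. A hedged sketch of how they are applied to a custom autograd function (the `ScaleBy2` example is hypothetical and needs a CUDA device):

```python
import torch
from mamba_ssm.utils.torch import custom_fwd, custom_bwd   # path assumed from this build

class ScaleBy2(torch.autograd.Function):
    @staticmethod
    @custom_fwd          # adds device_type="cuda" automatically on newer PyTorch
    def forward(ctx, x):
        return x * 2

    @staticmethod
    @custom_bwd          # backward runs under the same autocast state as forward
    def backward(ctx, grad_out):
        return grad_out * 2

x = torch.randn(4, device="cuda", requires_grad=True)
with torch.autocast("cuda", dtype=torch.float16):
    ScaleBy2.apply(x).sum().backward()
```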
build/torch210-cxx11-cu128-x86_64-linux/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ __version__ = "2.2.4"
2
+
3
+ from .ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
4
+ from .modules.mamba_simple import Mamba
5
+ from .modules.mamba2 import Mamba2
6
+ from .models.mixer_seq_simple import MambaLMHeadModel
7
+
8
+ __all__ = [
9
+ "selective_scan_fn",
10
+ "mamba_inner_fn",
11
+ "Mamba",
12
+ "Mamba2",
13
+ "MambaLMHeadModel",
14
+ ]
build/torch210-cxx11-cu128-x86_64-linux/_mamba_ssm_b2a7fd5.abi3.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2cebad781003a612eea29f35ebaf4a1905057ac6e20cdd12a216e4e403b34095
3
+ size 610662240
build/torch210-cxx11-cu128-x86_64-linux/_ops.py ADDED
@@ -0,0 +1,9 @@
1
+ import torch
2
+ from . import _mamba_ssm_b2a7fd5
3
+ ops = torch.ops._mamba_ssm_b2a7fd5
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_mamba_ssm_b2a7fd5::{op_name}"
build/torch210-cxx11-cu128-x86_64-linux/distributed/__init__.py ADDED
File without changes
build/torch210-cxx11-cu128-x86_64-linux/distributed/distributed_utils.py ADDED
@@ -0,0 +1,144 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from torch import Tensor
5
+ from torch.distributed import ProcessGroup
6
+
7
+ # `all_gather_into_tensor` and `reduce_scatter_tensor` are the new names for
8
+ # `_all_gather_base` and `_reduce_scatter_base`. They require a recent
9
+ # version of PyTorch. The following 4 lines are for backward compatibility with
10
+ # older PyTorch versions.
11
+ if "all_gather_into_tensor" not in dir(torch.distributed):
12
+ torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
13
+ if "reduce_scatter_tensor" not in dir(torch.distributed):
14
+ torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
15
+
16
+
17
+ # Raw operation, does not support autograd, but does support async
18
+ def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
19
+ world_size = torch.distributed.get_world_size(process_group)
20
+ output = torch.empty(
21
+ world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
22
+ )
23
+ handle = torch.distributed.all_gather_into_tensor(
24
+ output, input_.contiguous(), group=process_group, async_op=async_op
25
+ )
26
+ return output, handle
27
+
28
+
29
+ # Raw operation, does not support autograd, but does support async
30
+ def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
31
+ world_size = torch.distributed.get_world_size(process_group)
32
+ assert input_.shape[0] % world_size == 0
33
+ output = torch.empty(
34
+ input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
35
+ )
36
+ handle = torch.distributed.reduce_scatter_tensor(
37
+ output, input_.contiguous(), group=process_group, async_op=async_op
38
+ )
39
+ return output, handle
40
+
41
+
42
+ # Raw operation, does not support autograd, but does support async
43
+ def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
44
+ input_ = input_.contiguous()
45
+ handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
46
+ return input_, handle
47
+
48
+
49
+ class AllGatherFunc(torch.autograd.Function):
50
+ """Gather the input from sequence parallel region and concatenate."""
51
+
52
+ @staticmethod
53
+ def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
54
+ ctx.process_group = process_group
55
+ output, _ = all_gather_raw(input_, process_group)
56
+ return output
57
+
58
+ @staticmethod
59
+ def backward(ctx, grad_output: Tensor):
60
+ grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
61
+ return grad_input, None
62
+
63
+
64
+ # Supports autograd, but does not support async
65
+ all_gather = AllGatherFunc.apply
66
+
67
+
68
+ class ReduceScatterFunc(torch.autograd.Function):
69
+ """Reduce scatter the input from the sequence parallel region and concatenate."""
70
+
71
+ @staticmethod
72
+ def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
73
+ ctx.process_group = process_group
74
+ output, _ = reduce_scatter_raw(input_, process_group)
75
+ return output
76
+
77
+ @staticmethod
78
+ def backward(ctx, grad_output: Tensor):
79
+ grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
80
+ return grad_input, None
81
+
82
+
83
+ # Supports autograd, but does not support async
84
+ reduce_scatter = ReduceScatterFunc.apply
85
+
86
+
87
+ class AllReduceFunc(torch.autograd.Function):
88
+ """Gather the input from sequence parallel region and concatenate."""
89
+
90
+ @staticmethod
91
+ def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
92
+ ctx.process_group = process_group
93
+ output, _ = all_reduce_raw(input_, process_group)
94
+ return output
95
+
96
+ @staticmethod
97
+ def backward(ctx, grad_output: Tensor):
98
+ return grad_output, None
99
+
100
+
101
+ # Supports autograd, but does not support async
102
+ all_reduce = AllReduceFunc.apply
103
+
104
+
105
+ def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
106
+ # We want to iterate over parameters with _shared_params=True in the same order,
107
+ # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
108
+ params_shared = {
109
+ name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
110
+ }
111
+ for _, p in sorted(params_shared.items()):
112
+ with torch.no_grad():
113
+ # Broadcast needs src to be global rank, not group rank
114
+ torch.distributed.broadcast(
115
+ p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
116
+ )
117
+
118
+
119
+ # Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
120
+ def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
121
+ # We want to iterate over parameters with _sequence_parallel=True in the same order,
122
+ # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
123
+ params_seqparallel = {
124
+ name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
125
+ }
126
+ grads = [p.grad for _, p in sorted(params_seqparallel.items())]
127
+ if grads:
128
+ with torch.no_grad():
129
+ coalesced = torch._utils._flatten_dense_tensors(grads)
130
+ torch.distributed.all_reduce(coalesced, group=process_group)
131
+ for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
132
+ buf.copy_(synced)
133
+
134
+
135
+ def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
136
+ """Get the dim for the local rank derived from splitting dim on world_size processes.
137
+
138
+ The split may not be even across the world_size processes.
139
+ """
140
+ multiple = dim // multiple_of
141
+ div = multiple // world_size
142
+ mod = multiple % world_size
143
+ local_multiple = div + int(local_rank < mod)
144
+ return local_multiple * multiple_of
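`get_dim_for_local_rank` is pure bookkeeping, so it can be checked without a process group. A worked example, assuming the import path of this build:

```python
from mamba_ssm.distributed.distributed_utils import get_dim_for_local_rank

# 1000 features, kept in multiples of 8, split across 3 ranks: 336 + 336 + 328 = 1000.
dims = [get_dim_for_local_rank(1000, world_size=3, local_rank=r, multiple_of=8) for r in range(3)]
print(dims, sum(dims))   # [336, 336, 328] 1000
```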
build/torch210-cxx11-cu128-x86_64-linux/distributed/tensor_parallel.py ADDED
@@ -0,0 +1,296 @@
1
+ # Copyright (c) 2024, Tri Dao.
2
+ # The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
3
+ from typing import Optional
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch import Tensor
9
+ from torch.distributed import ProcessGroup
10
+ from ..utils.torch import custom_bwd, custom_fwd
11
+
12
+ from einops import rearrange
13
+
14
+ from ..distributed.distributed_utils import (
15
+ all_gather_raw,
16
+ all_reduce,
17
+ all_reduce_raw,
18
+ reduce_scatter,
19
+ reduce_scatter_raw,
20
+ )
21
+
22
+
23
+ class ParallelLinearFunc(torch.autograd.Function):
24
+ @staticmethod
25
+ @custom_fwd
26
+ def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
27
+ """
28
+ If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
29
+ with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
30
+ """
31
+ ctx.compute_weight_gradient = weight.requires_grad
32
+ ctx.process_group = process_group
33
+ ctx.sequence_parallel = sequence_parallel
34
+
35
+ if torch.is_autocast_enabled():
36
+ x = x.to(dtype=torch.get_autocast_gpu_dtype())
37
+ x = x.contiguous()
38
+ if process_group is not None and sequence_parallel:
39
+ # We want to kick off the all_gather early, before weight dtype conversion
40
+ total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
41
+ else:
42
+ total_x = x
43
+
44
+ if torch.is_autocast_enabled():
45
+ weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
46
+ bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
47
+ weight = weight.contiguous()
48
+ if process_group is not None and sequence_parallel:
49
+ handle_x.wait()
50
+ batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
51
+ batch_dim = batch_shape.numel()
52
+ # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
53
+ output = F.linear(total_x, weight, bias)
54
+ if ctx.compute_weight_gradient:
55
+ ctx.save_for_backward(x, weight)
56
+ else:
57
+ ctx.save_for_backward(weight)
58
+ return output
59
+
60
+ @staticmethod
61
+ @custom_bwd
62
+ def backward(ctx, grad_output):
63
+ grad_output = grad_output.contiguous()
64
+ process_group = ctx.process_group
65
+ sequence_parallel = ctx.sequence_parallel
66
+ if ctx.compute_weight_gradient:
67
+ x, weight = ctx.saved_tensors
68
+ if process_group is not None and sequence_parallel:
69
+ total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
70
+ else:
71
+ total_x = x
72
+ else:
73
+ (weight,) = ctx.saved_tensors
74
+ total_x = None
75
+ batch_shape = grad_output.shape[:-1]
76
+ batch_dim = batch_shape.numel()
77
+ grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
78
+ if ctx.needs_input_grad[0]:
79
+ grad_input = F.linear(grad_output, weight.t())
80
+ grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
81
+ if process_group is not None:
82
+ reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
83
+ grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
84
+ else:
85
+ grad_input = None
86
+ if ctx.needs_input_grad[1]:
87
+ assert ctx.compute_weight_gradient
88
+ if process_group is not None and sequence_parallel:
89
+ handle_x.wait()
90
+ grad_weight = torch.einsum(
91
+ "bo,bi->oi", grad_output, total_x.reshape(batch_dim, total_x.shape[-1])
92
+ )
93
+ else:
94
+ grad_weight = None
95
+ grad_bias = grad_output.sum(dim=0) if ctx.needs_input_grad[2] else None
96
+ if process_group is not None and ctx.needs_input_grad[0]:
97
+ handle_grad_input.wait()
98
+ return grad_input, grad_weight, grad_bias, None, None
99
+
100
+
101
+ def parallel_linear_func(
102
+ x: Tensor,
103
+ weight: Tensor,
104
+ bias: Optional[Tensor] = None,
105
+ process_group: Optional[ProcessGroup] = None,
106
+ sequence_parallel: bool = True,
107
+ ):
108
+ return ParallelLinearFunc.apply(x, weight, bias, process_group, sequence_parallel)
109
+
110
+
111
+ class ColumnParallelLinear(nn.Linear):
112
+ def __init__(
113
+ self,
114
+ in_features: int,
115
+ out_features: int,
116
+ process_group: ProcessGroup,
117
+ bias: bool = True,
118
+ sequence_parallel=True,
119
+ multiple_of=1,
120
+ device=None,
121
+ dtype=None,
122
+ ) -> None:
123
+ world_size = torch.distributed.get_world_size(process_group)
124
+ if out_features % multiple_of:
125
+ raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}")
126
+ multiple = out_features // multiple_of
127
+ # We want to split @multiple across world_size, but it could be an uneven split
128
+ div = multiple // world_size
129
+ mod = multiple % world_size
130
+ # The first @mod ranks get @div + 1 copies, the rest get @div copies
131
+ local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
132
+ super().__init__(
133
+ in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype
134
+ )
135
+ self.process_group = process_group
136
+ self.sequence_parallel = sequence_parallel
137
+
138
+ def forward(self, x):
139
+ # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
140
+ # we do an all_gather of x before doing the matmul.
141
+ # If not, then the input is already gathered.
142
+ return parallel_linear_func(
143
+ x,
144
+ self.weight,
145
+ self.bias,
146
+ process_group=self.process_group,
147
+ sequence_parallel=self.sequence_parallel,
148
+ )
149
+
150
+
151
+ class RowParallelLinear(nn.Linear):
152
+ def __init__(
153
+ self,
154
+ in_features: int,
155
+ out_features: int,
156
+ process_group: ProcessGroup,
157
+ bias: bool = True,
158
+ sequence_parallel=True,
159
+ multiple_of=1,
160
+ device=None,
161
+ dtype=None,
162
+ ) -> None:
163
+ world_size = torch.distributed.get_world_size(process_group)
164
+ rank = torch.distributed.get_rank(process_group)
165
+ if in_features % multiple_of:
166
+ raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}")
167
+ multiple = in_features // multiple_of
168
+ # We want to split @multiple across world_size, but it could be an uneven split
169
+ div = multiple // world_size
170
+ mod = multiple % world_size
171
+ # The first @mod ranks get @div + 1 copies, the rest get @div copies
172
+ local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
173
+ # Only rank 0 will have bias
174
+ super().__init__(
175
+ local_multiple * multiple_of,
176
+ out_features,
177
+ bias=bias and rank == 0,
178
+ device=device,
179
+ dtype=dtype,
180
+ )
181
+ self.process_group = process_group
182
+ self.sequence_parallel = sequence_parallel
183
+
184
+ def forward(self, x):
185
+ """
186
+ We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
187
+ a reduce_scatter of the result.
188
+ """
189
+ out = parallel_linear_func(x, self.weight, self.bias)
190
+ reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
191
+ return reduce_fn(out, self.process_group)
192
+
193
+
194
+ class VocabParallelEmbedding(nn.Embedding):
195
+ def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
196
+ self.process_group = process_group
197
+ if process_group is not None:
198
+ world_size = torch.distributed.get_world_size(process_group)
199
+ if num_embeddings % world_size != 0:
200
+ raise ValueError(
201
+ f"num_embeddings ({num_embeddings}) must be divisible by "
202
+ f"world_size ({world_size})"
203
+ )
204
+ if world_size > 1 and padding_idx is not None:
205
+ raise RuntimeError("ParallelEmbedding does not support padding_idx")
206
+ else:
207
+ world_size = 1
208
+ super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)
209
+
210
+ def forward(self, input: Tensor) -> Tensor:
211
+ if self.process_group is None:
212
+ return super().forward(input)
213
+ else:
214
+ rank = torch.distributed.get_rank(self.process_group)
215
+ vocab_size = self.num_embeddings
216
+ vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
217
+ # Create a mask of valid vocab ids (1 means it needs to be masked).
218
+ input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
219
+ input = input - vocab_start_index
220
+ input[input_ids_mask] = 0
221
+ embeddings = super().forward(input)
222
+ embeddings[input_ids_mask] = 0.0
223
+ return embeddings
224
+
225
+
226
+ class ColumnParallelEmbedding(nn.Embedding):
227
+ def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):
228
+ self.process_group = process_group
229
+ if process_group is not None:
230
+ world_size = torch.distributed.get_world_size(process_group)
231
+ if embedding_dim % world_size != 0:
232
+ raise ValueError(
233
+ f"embedding_dim ({embedding_dim}) must be divisible by "
234
+ f"world_size ({world_size})"
235
+ )
236
+ else:
237
+ world_size = 1
238
+ super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)
239
+
240
+
241
+ class ParallelEmbeddings(nn.Module):
242
+ def __init__(
243
+ self,
244
+ embed_dim,
245
+ vocab_size,
246
+ max_position_embeddings,
247
+ process_group,
248
+ padding_idx=None,
249
+ sequence_parallel=True,
250
+ device=None,
251
+ dtype=None,
252
+ ):
253
+ """
254
+ If max_position_embeddings <= 0, there's no position embeddings
255
+ """
256
+ factory_kwargs = {"device": device, "dtype": dtype}
257
+ super().__init__()
258
+ self.process_group = process_group
259
+ self.sequence_parallel = sequence_parallel
260
+ self.word_embeddings = VocabParallelEmbedding(
261
+ vocab_size,
262
+ embed_dim,
263
+ padding_idx=padding_idx,
264
+ process_group=process_group,
265
+ **factory_kwargs,
266
+ )
267
+ self.max_position_embeddings = max_position_embeddings
268
+ if self.max_position_embeddings > 0:
269
+ self.position_embeddings = ColumnParallelEmbedding(
270
+ max_position_embeddings, embed_dim, process_group=process_group, **factory_kwargs
271
+ )
272
+
273
+ def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
274
+ """
275
+ input_ids: (batch, seqlen)
276
+ position_ids: (batch, seqlen)
277
+ """
278
+ batch_size, seqlen = input_ids.shape
279
+ world_size = torch.distributed.get_world_size(self.process_group)
280
+ embeddings = self.word_embeddings(input_ids)
281
+ if self.max_position_embeddings > 0:
282
+ if position_ids is None:
283
+ position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
284
+ position_embeddings = self.position_embeddings(position_ids)
285
+ if world_size <= 1:
286
+ embeddings = embeddings + position_embeddings
287
+ else:
288
+ partition_dim = self.position_embeddings.embedding_dim
289
+ rank = torch.distributed.get_rank(self.process_group)
290
+ embeddings[
291
+ ..., rank * partition_dim : (rank + 1) * partition_dim
292
+ ] += position_embeddings
293
+ if combine_batch_seqlen_dim:
294
+ embeddings = rearrange(embeddings, "b s d -> (b s) d")
295
+ reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
296
+ return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
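A minimal two-GPU sketch of pairing the column- and row-parallel layers above, assuming two CUDA devices, NCCL, and the import path of this build; with `sequence_parallel=False` the only collective is the all-reduce inside `RowParallelLinear`. The script name is hypothetical.

```python
# Launch with: torchrun --nproc_per_node=2 tp_sketch.py
import torch
import torch.distributed as dist
import torch.nn.functional as F
from mamba_ssm.distributed.tensor_parallel import ColumnParallelLinear, RowParallelLinear

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)
pg = dist.group.WORLD

d_model, d_inner = 256, 512
fc1 = ColumnParallelLinear(d_model, d_inner, pg, sequence_parallel=False, device="cuda")
fc2 = RowParallelLinear(d_inner, d_model, pg, sequence_parallel=False, device="cuda")

x = torch.randn(4, 16, d_model, device="cuda")
# Each rank holds a shard of the weights; the all-reduce inside RowParallelLinear
# sums the partial products, so y has the full d_model width on every rank.
y = fc2(F.silu(fc1(x)))
print(rank, y.shape)   # torch.Size([4, 16, 256]) on both ranks
dist.destroy_process_group()
```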
build/torch210-cxx11-cu128-x86_64-linux/mamba_ssm/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ import ctypes
2
+ import sys
3
+
4
+ import importlib
5
+ from pathlib import Path
6
+ from types import ModuleType
7
+
8
+ def _import_from_path(file_path: Path) -> ModuleType:
9
+ # We cannot use the module name as-is: after adding it to `sys.modules`,
10
+ # it would also be used for other imports. So, we make a module name that
11
+ # depends on the path for it to be unique using the hex-encoded hash of
12
+ # the path.
13
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
+ module_name = path_hash
15
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
16
+ if spec is None:
17
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
+ module = importlib.util.module_from_spec(spec)
19
+ if module is None:
20
+ raise ImportError(f"Cannot load module {module_name} from spec")
21
+ sys.modules[module_name] = module
22
+ spec.loader.exec_module(module) # type: ignore
23
+ return module
24
+
25
+
26
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch210-cxx11-cu128-x86_64-linux/metadata.json ADDED
@@ -0,0 +1 @@
1
+ {"python-depends":[]}
build/torch210-cxx11-cu128-x86_64-linux/models/__init__.py ADDED
File without changes
build/torch210-cxx11-cu128-x86_64-linux/models/config_mamba.py ADDED
@@ -0,0 +1,18 @@
1
+ from dataclasses import dataclass, field
2
+
3
+
4
+ @dataclass
5
+ class MambaConfig:
6
+
7
+ d_model: int = 2560
8
+ d_intermediate: int = 0
9
+ n_layer: int = 64
10
+ vocab_size: int = 50277
11
+ ssm_cfg: dict = field(default_factory=dict)
12
+ attn_layer_idx: list = field(default_factory=list)
13
+ attn_cfg: dict = field(default_factory=dict)
14
+ rms_norm: bool = True
15
+ residual_in_fp32: bool = True
16
+ fused_add_norm: bool = True
17
+ pad_vocab_size_multiple: int = 8
18
+ tie_embeddings: bool = True
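A small configuration sketch; the field values are illustrative, not a released checkpoint, and the import path follows this build's layout.

```python
from mamba_ssm.models.config_mamba import MambaConfig

cfg = MambaConfig(
    d_model=768,
    n_layer=24,
    vocab_size=50277,
    ssm_cfg={"layer": "Mamba2"},   # create_block() pops "layer" to choose Mamba vs Mamba2
    pad_vocab_size_multiple=16,
)
print(cfg.d_intermediate)          # 0 -> no gated MLP between mixer blocks
```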
build/torch210-cxx11-cu128-x86_64-linux/models/mixer_seq_simple.py ADDED
@@ -0,0 +1,309 @@
1
+ # Copyright (c) 2023, Albert Gu, Tri Dao.
2
+
3
+ import math
4
+ from functools import partial
5
+ import json
6
+ import os
7
+ import copy
8
+
9
+ from collections import namedtuple
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+ from .config_mamba import MambaConfig
15
+ from ..modules.mamba_simple import Mamba
16
+ from ..modules.mamba2 import Mamba2
17
+ from ..modules.mha import MHA
18
+ from ..modules.mlp import GatedMLP
19
+ from ..modules.block import Block
20
+ from ..utils.generation import GenerationMixin
21
+ from ..utils.hf import load_config_hf, load_state_dict_hf
22
+
23
+ try:
24
+ from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
25
+ except ImportError:
26
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
27
+
28
+
29
+ def create_block(
30
+ d_model,
31
+ d_intermediate,
32
+ ssm_cfg=None,
33
+ attn_layer_idx=None,
34
+ attn_cfg=None,
35
+ norm_epsilon=1e-5,
36
+ rms_norm=False,
37
+ residual_in_fp32=False,
38
+ fused_add_norm=False,
39
+ layer_idx=None,
40
+ device=None,
41
+ dtype=None,
42
+ ):
43
+ if ssm_cfg is None:
44
+ ssm_cfg = {}
45
+ if attn_layer_idx is None:
46
+ attn_layer_idx = []
47
+ if attn_cfg is None:
48
+ attn_cfg = {}
49
+ factory_kwargs = {"device": device, "dtype": dtype}
50
+ if layer_idx not in attn_layer_idx:
51
+ # Create a copy of the config to modify
52
+ ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
53
+ ssm_layer = ssm_cfg.pop("layer", "Mamba1")
54
+ if ssm_layer not in ["Mamba1", "Mamba2"]:
55
+ raise ValueError(f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2")
56
+ mixer_cls = partial(
57
+ Mamba2 if ssm_layer == "Mamba2" else Mamba,
58
+ layer_idx=layer_idx,
59
+ **ssm_cfg,
60
+ **factory_kwargs
61
+ )
62
+ else:
63
+ mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
64
+ norm_cls = partial(
65
+ nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
66
+ )
67
+ if d_intermediate == 0:
68
+ mlp_cls = nn.Identity
69
+ else:
70
+ mlp_cls = partial(
71
+ GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs
72
+ )
73
+ block = Block(
74
+ d_model,
75
+ mixer_cls,
76
+ mlp_cls,
77
+ norm_cls=norm_cls,
78
+ fused_add_norm=fused_add_norm,
79
+ residual_in_fp32=residual_in_fp32,
80
+ )
81
+ block.layer_idx = layer_idx
82
+ return block
83
+
84
+
85
+ # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
86
+ def _init_weights(
87
+ module,
88
+ n_layer,
89
+ initializer_range=0.02, # Now only used for embedding layer.
90
+ rescale_prenorm_residual=True,
91
+ n_residuals_per_layer=1, # Change to 2 if we have MLP
92
+ ):
93
+ if isinstance(module, nn.Linear):
94
+ if module.bias is not None:
95
+ if not getattr(module.bias, "_no_reinit", False):
96
+ nn.init.zeros_(module.bias)
97
+ elif isinstance(module, nn.Embedding):
98
+ nn.init.normal_(module.weight, std=initializer_range)
99
+
100
+ if rescale_prenorm_residual:
101
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
102
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
103
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
104
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
105
+ #
106
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
107
+ for name, p in module.named_parameters():
108
+ if name in ["out_proj.weight", "fc2.weight"]:
109
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
110
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
111
+ # We need to reinit p since this code could be called multiple times
112
+ # Having just p *= scale would repeatedly scale it down
113
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
114
+ with torch.no_grad():
115
+ p /= math.sqrt(n_residuals_per_layer * n_layer)
116
+
117
+
118
+ class MixerModel(nn.Module):
119
+ def __init__(
120
+ self,
121
+ d_model: int,
122
+ n_layer: int,
123
+ d_intermediate: int,
124
+ vocab_size: int,
125
+ ssm_cfg=None,
126
+ attn_layer_idx=None,
127
+ attn_cfg=None,
128
+ norm_epsilon: float = 1e-5,
129
+ rms_norm: bool = False,
130
+ initializer_cfg=None,
131
+ fused_add_norm=False,
132
+ residual_in_fp32=False,
133
+ device=None,
134
+ dtype=None,
135
+ ) -> None:
136
+ factory_kwargs = {"device": device, "dtype": dtype}
137
+ super().__init__()
138
+ self.residual_in_fp32 = residual_in_fp32
139
+
140
+ self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
141
+
142
+ # We change the order of residual and layer norm:
143
+ # Instead of LN -> Attn / MLP -> Add, we do:
144
+ # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
145
+ # the main branch (output of MLP / Mixer). The model definition is unchanged.
146
+ # This is for performance reason: we can fuse add + layer_norm.
147
+ self.fused_add_norm = fused_add_norm
148
+ if self.fused_add_norm:
149
+ if layer_norm_fn is None or rms_norm_fn is None:
150
+ raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
151
+
152
+ self.layers = nn.ModuleList(
153
+ [
154
+ create_block(
155
+ d_model,
156
+ d_intermediate=d_intermediate,
157
+ ssm_cfg=ssm_cfg,
158
+ attn_layer_idx=attn_layer_idx,
159
+ attn_cfg=attn_cfg,
160
+ norm_epsilon=norm_epsilon,
161
+ rms_norm=rms_norm,
162
+ residual_in_fp32=residual_in_fp32,
163
+ fused_add_norm=fused_add_norm,
164
+ layer_idx=i,
165
+ **factory_kwargs,
166
+ )
167
+ for i in range(n_layer)
168
+ ]
169
+ )
170
+
171
+ self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
172
+ d_model, eps=norm_epsilon, **factory_kwargs
173
+ )
174
+
175
+ self.apply(
176
+ partial(
177
+ _init_weights,
178
+ n_layer=n_layer,
179
+ **(initializer_cfg if initializer_cfg is not None else {}),
180
+ n_residuals_per_layer=1 if d_intermediate == 0 else 2, # 2 if we have MLP
181
+ )
182
+ )
183
+
184
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
185
+ return {
186
+ i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
187
+ for i, layer in enumerate(self.layers)
188
+ }
189
+
190
+ def forward(self, input_ids, inference_params=None, **mixer_kwargs):
191
+ hidden_states = self.embedding(input_ids)
192
+ residual = None
193
+ for layer in self.layers:
194
+ hidden_states, residual = layer(
195
+ hidden_states, residual, inference_params=inference_params, **mixer_kwargs
196
+ )
197
+ if not self.fused_add_norm:
198
+ residual = (hidden_states + residual) if residual is not None else hidden_states
199
+ hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
200
+ else:
201
+ # Set prenorm=False here since we don't need the residual
202
+ hidden_states = layer_norm_fn(
203
+ hidden_states,
204
+ self.norm_f.weight,
205
+ self.norm_f.bias,
206
+ eps=self.norm_f.eps,
207
+ residual=residual,
208
+ prenorm=False,
209
+ residual_in_fp32=self.residual_in_fp32,
210
+ is_rms_norm=isinstance(self.norm_f, RMSNorm)
211
+ )
212
+ return hidden_states
213
+
214
+
215
+ class MambaLMHeadModel(nn.Module, GenerationMixin):
216
+
217
+ def __init__(
218
+ self,
219
+ config: MambaConfig,
220
+ initializer_cfg=None,
221
+ device=None,
222
+ dtype=None,
223
+ ) -> None:
224
+ self.config = config
225
+ d_model = config.d_model
226
+ n_layer = config.n_layer
227
+ d_intermediate = config.d_intermediate
228
+ vocab_size = config.vocab_size
229
+ ssm_cfg = config.ssm_cfg
230
+ attn_layer_idx = config.attn_layer_idx
231
+ attn_cfg = config.attn_cfg
232
+ rms_norm = config.rms_norm
233
+ residual_in_fp32 = config.residual_in_fp32
234
+ fused_add_norm = config.fused_add_norm
235
+ pad_vocab_size_multiple = config.pad_vocab_size_multiple
236
+ factory_kwargs = {"device": device, "dtype": dtype}
237
+
238
+ super().__init__()
239
+ if vocab_size % pad_vocab_size_multiple != 0:
240
+ vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
241
+ self.backbone = MixerModel(
242
+ d_model=d_model,
243
+ n_layer=n_layer,
244
+ d_intermediate=d_intermediate,
245
+ vocab_size=vocab_size,
246
+ ssm_cfg=ssm_cfg,
247
+ attn_layer_idx=attn_layer_idx,
248
+ attn_cfg=attn_cfg,
249
+ rms_norm=rms_norm,
250
+ initializer_cfg=initializer_cfg,
251
+ fused_add_norm=fused_add_norm,
252
+ residual_in_fp32=residual_in_fp32,
253
+ **factory_kwargs,
254
+ )
255
+ self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
256
+
257
+ # Initialize weights and apply final processing
258
+ self.apply(
259
+ partial(
260
+ _init_weights,
261
+ n_layer=n_layer,
262
+ **(initializer_cfg if initializer_cfg is not None else {}),
263
+ )
264
+ )
265
+ self.tie_weights()
266
+
267
+ def tie_weights(self):
268
+ if self.config.tie_embeddings:
269
+ self.lm_head.weight = self.backbone.embedding.weight
270
+
271
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
272
+ return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
273
+
274
+ def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, **mixer_kwargs):
275
+ """
276
+ "position_ids" is just to be compatible with Transformer generation. We don't use it.
277
+ num_last_tokens: if > 0, only return the logits for the last n tokens
278
+ """
279
+ hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs)
280
+ if num_last_tokens > 0:
281
+ hidden_states = hidden_states[:, -num_last_tokens:]
282
+ lm_logits = self.lm_head(hidden_states)
283
+ CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
284
+ return CausalLMOutput(logits=lm_logits)
285
+
286
+ @classmethod
287
+ def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
288
+ config_data = load_config_hf(pretrained_model_name)
289
+ config = MambaConfig(**config_data)
290
+ model = cls(config, device=device, dtype=dtype, **kwargs)
291
+ model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
292
+ return model
293
+
294
+ def save_pretrained(self, save_directory):
295
+ """
296
+ Minimal implementation of save_pretrained for MambaLMHeadModel.
297
+ Save the model and its configuration file to a directory.
298
+ """
299
+ # Ensure save_directory exists
300
+ os.makedirs(save_directory, exist_ok=True)
301
+
302
+ # Save the model's state_dict
303
+ model_path = os.path.join(save_directory, 'pytorch_model.bin')
304
+ torch.save(self.state_dict(), model_path)
305
+
306
+ # Save the configuration of the model
307
+ config_path = os.path.join(save_directory, 'config.json')
308
+ with open(config_path, 'w') as f:
309
+ json.dump(self.config.__dict__, f, indent=4)
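An end-to-end sketch of the class above: load a Hub checkpoint, decode, and save. The checkpoint and tokenizer names are the ones commonly paired with these models but are assumptions here; a CUDA device is required.

```python
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel   # path assumed from this build

tok = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device="cuda", dtype=torch.float16)

input_ids = tok("The state space model", return_tensors="pt").input_ids.to("cuda")
out = model.generate(input_ids, max_length=64, top_k=50, top_p=0.9, temperature=0.8, cg=True)
print(tok.decode(out[0]))

model.save_pretrained("./mamba-130m-local")   # writes pytorch_model.bin + config.json
```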
build/torch210-cxx11-cu128-x86_64-linux/modules/__init__.py ADDED
File without changes
build/torch210-cxx11-cu128-x86_64-linux/modules/block.py ADDED
@@ -0,0 +1,107 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+ from typing import Optional
3
+
4
+ import torch
5
+ from torch import nn, Tensor
6
+
7
+ from ..ops.triton.layer_norm import RMSNorm, layer_norm_fn
8
+
9
+
10
+ class Block(nn.Module):
11
+ def __init__(
12
+ self,
13
+ dim,
14
+ mixer_cls,
15
+ mlp_cls,
16
+ norm_cls=nn.LayerNorm,
17
+ fused_add_norm=False,
18
+ residual_in_fp32=False,
19
+ ):
20
+ """
21
+ Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection"
22
+
23
+ This Block has a slightly different structure compared to a regular
24
+ prenorm Transformer block.
25
+ The standard block is: LN -> MHA/MLP -> Add.
26
+ [Ref: https://arxiv.org/abs/2002.04745]
27
+ Here we have: Add -> LN -> Mixer, returning both
28
+ the hidden_states (output of the mixer) and the residual.
29
+ This is purely for performance reasons, as we can fuse add and LayerNorm.
30
+ The residual needs to be provided (except for the very first block).
31
+ """
32
+ super().__init__()
33
+ self.residual_in_fp32 = residual_in_fp32
34
+ self.fused_add_norm = fused_add_norm
35
+ self.norm = norm_cls(dim)
36
+ self.mixer = mixer_cls(dim)
37
+ if mlp_cls is not nn.Identity:
38
+ self.norm2 = norm_cls(dim)
39
+ self.mlp = mlp_cls(dim)
40
+ else:
41
+ self.mlp = None
42
+ if self.fused_add_norm:
43
+ assert RMSNorm is not None, "RMSNorm import fails"
44
+ assert isinstance(
45
+ self.norm, (nn.LayerNorm, RMSNorm)
46
+ ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
47
+
48
+ def forward(
49
+ self,
50
+ hidden_states: Tensor,
51
+ residual: Optional[Tensor] = None,
52
+ inference_params=None,
53
+ **mixer_kwargs
54
+ ):
55
+ r"""Pass the input through the encoder layer.
56
+
57
+ Args:
58
+ hidden_states: the sequence to the encoder layer (required).
59
+ residual: hidden_states = Mixer(LN(residual))
60
+ """
61
+ if not self.fused_add_norm:
62
+ residual = (
63
+ (hidden_states + residual) if residual is not None else hidden_states
64
+ )
65
+ hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
66
+ if self.residual_in_fp32:
67
+ residual = residual.to(torch.float32)
68
+ else:
69
+ hidden_states, residual = layer_norm_fn(
70
+ hidden_states,
71
+ self.norm.weight,
72
+ self.norm.bias,
73
+ residual=residual,
74
+ prenorm=True,
75
+ residual_in_fp32=self.residual_in_fp32,
76
+ eps=self.norm.eps,
77
+ is_rms_norm=isinstance(self.norm, RMSNorm),
78
+ )
79
+ hidden_states = self.mixer(
80
+ hidden_states, inference_params=inference_params, **mixer_kwargs
81
+ )
82
+
83
+ if self.mlp is not None:
84
+ if not self.fused_add_norm:
85
+ residual = hidden_states + residual
86
+ hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
87
+ if self.residual_in_fp32:
88
+ residual = residual.to(torch.float32)
89
+ else:
90
+ hidden_states, residual = layer_norm_fn(
91
+ hidden_states,
92
+ self.norm2.weight,
93
+ self.norm2.bias,
94
+ residual=residual,
95
+ prenorm=True,
96
+ residual_in_fp32=self.residual_in_fp32,
97
+ eps=self.norm2.eps,
98
+ is_rms_norm=isinstance(self.norm2, RMSNorm),
99
+ )
100
+ hidden_states = self.mlp(hidden_states)
101
+
102
+ return hidden_states, residual
103
+
104
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
105
+ return self.mixer.allocate_inference_cache(
106
+ batch_size, max_seqlen, dtype=dtype, **kwargs
107
+ )
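For reference, the unfused Add -> LN -> Mixer ordering that the docstring above describes (and that `layer_norm_fn` fuses when `fused_add_norm=True`) can be written out in a few lines of plain PyTorch; the `prenorm_step` helper below is hypothetical.

```python
import torch

def prenorm_step(hidden_states, residual, norm, mixer, residual_in_fp32=True):
    # Add: fold the previous mixer output into the running residual stream.
    residual = hidden_states + residual if residual is not None else hidden_states
    # LN -> Mixer: normalize the residual and feed it to the mixer.
    hidden_states = mixer(norm(residual.to(norm.weight.dtype)))
    if residual_in_fp32:
        residual = residual.to(torch.float32)
    return hidden_states, residual

norm = torch.nn.LayerNorm(16)
mixer = torch.nn.Linear(16, 16)        # stand-in for a Mamba / MHA mixer
h, res = prenorm_step(torch.randn(2, 8, 16), None, norm, mixer)
```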
build/torch210-cxx11-cu128-x86_64-linux/modules/mamba2.py ADDED
@@ -0,0 +1,502 @@
1
+ # Copyright (c) 2024, Tri Dao, Albert Gu.
2
+
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from einops import rearrange, repeat
10
+
11
+ try:
12
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
13
+ except ImportError:
14
+ causal_conv1d_fn, causal_conv1d_update = None, None
15
+
16
+ try:
17
+ from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states
18
+ except ImportError:
19
+ causal_conv1d_varlen_states = None
20
+
21
+ try:
22
+ from ..ops.triton.selective_state_update import selective_state_update
23
+ except ImportError:
24
+ selective_state_update = None
25
+
26
+ from ..ops.triton.layernorm_gated import RMSNorm as RMSNormGated
27
+
28
+ from ..distributed.tensor_parallel import ColumnParallelLinear, RowParallelLinear
29
+ from ..distributed.distributed_utils import all_reduce, reduce_scatter
30
+
31
+ from ..ops.triton.ssd_combined import mamba_chunk_scan_combined
32
+ from ..ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
33
+
34
+ from huggingface_hub import PyTorchModelHubMixin
35
+
36
+
37
+ class Mamba2(nn.Module, PyTorchModelHubMixin):
38
+ def __init__(
39
+ self,
40
+ d_model,
41
+ d_state=128,
42
+ d_conv=4,
43
+ conv_init=None,
44
+ expand=2,
45
+ headdim=64,
46
+ d_ssm=None, # If not None, we only apply SSM on this many dimensions, the rest uses gated MLP
47
+ ngroups=1,
48
+ A_init_range=(1, 16),
49
+ D_has_hdim=False,
50
+ rmsnorm=True,
51
+ norm_before_gate=False,
52
+ dt_min=0.001,
53
+ dt_max=0.1,
54
+ dt_init_floor=1e-4,
55
+ dt_limit=(0.0, float("inf")),
56
+ bias=False,
57
+ conv_bias=True,
58
+ # Fused kernel and sharding options
59
+ chunk_size=256,
60
+ use_mem_eff_path=True,
61
+ layer_idx=None, # Absorb kwarg for general module
62
+ process_group=None,
63
+ sequence_parallel=True,
64
+ device=None,
65
+ dtype=None,
66
+ ):
67
+ factory_kwargs = {"device": device, "dtype": dtype}
68
+ super().__init__()
69
+ self.d_model = d_model
70
+ self.d_state = d_state
71
+ self.d_conv = d_conv
72
+ self.conv_init = conv_init
73
+ self.expand = expand
74
+ self.process_group = process_group
75
+ self.sequence_parallel = sequence_parallel
76
+ self.world_size = 1 if process_group is None else process_group.size()
77
+ self.local_rank = 0 if process_group is None else process_group.rank()
78
+ self.d_inner = (self.expand * self.d_model) // self.world_size
79
+ assert self.d_inner * self.world_size == self.expand * self.d_model
80
+ self.headdim = headdim
81
+ self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size
82
+ assert ngroups % self.world_size == 0
83
+ self.ngroups = ngroups // self.world_size
84
+ assert self.d_ssm % self.headdim == 0
85
+ self.nheads = self.d_ssm // self.headdim
86
+ self.D_has_hdim = D_has_hdim
87
+ self.rmsnorm = rmsnorm
88
+ self.norm_before_gate = norm_before_gate
89
+ self.dt_limit = dt_limit
90
+ self.activation = "silu"
91
+ self.chunk_size = chunk_size
92
+ self.use_mem_eff_path = use_mem_eff_path
93
+ self.layer_idx = layer_idx
94
+
95
+ # Order: [z, x, B, C, dt]
96
+ d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
97
+ if self.process_group is None:
98
+ self.in_proj = nn.Linear(
99
+ self.d_model, d_in_proj, bias=bias, **factory_kwargs
100
+ )
101
+ else:
102
+ self.in_proj = ColumnParallelLinear(
103
+ self.d_model,
104
+ d_in_proj * self.world_size,
105
+ bias=bias,
106
+ process_group=self.process_group,
107
+ sequence_parallel=self.sequence_parallel,
108
+ **factory_kwargs,
109
+ )
110
+
111
+ conv_dim = self.d_ssm + 2 * self.ngroups * self.d_state
112
+ self.conv1d = nn.Conv1d(
113
+ in_channels=conv_dim,
114
+ out_channels=conv_dim,
115
+ bias=conv_bias,
116
+ kernel_size=d_conv,
117
+ groups=conv_dim,
118
+ padding=d_conv - 1,
119
+ **factory_kwargs,
120
+ )
121
+ if self.conv_init is not None:
122
+ nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
123
+
124
+ self.act = nn.SiLU()
125
+
126
+ # Initialize log dt bias
127
+ dt = torch.exp(
128
+ torch.rand(self.nheads, **factory_kwargs)
129
+ * (math.log(dt_max) - math.log(dt_min))
130
+ + math.log(dt_min)
131
+ )
132
+ dt = torch.clamp(dt, min=dt_init_floor)
133
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
134
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
135
+ self.dt_bias = nn.Parameter(inv_dt)
136
+ # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
137
+ # name.endswith("bias") in param_grouping.py
138
+ self.dt_bias._no_weight_decay = True
139
+
140
+ assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
141
+ A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(
142
+ *A_init_range
143
+ )
144
+ A_log = torch.log(A).to(dtype=dtype)
145
+ self.A_log = nn.Parameter(A_log)
146
+ self.A_log._no_weight_decay = True
147
+
148
+ # D "skip" parameter
149
+ self.D = nn.Parameter(
150
+ torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device)
151
+ )
152
+ self.D._no_weight_decay = True
153
+
154
+ if self.rmsnorm:
155
+ assert RMSNormGated is not None
156
+ self.norm = RMSNormGated(
157
+ self.d_ssm,
158
+ eps=1e-5,
159
+ norm_before_gate=self.norm_before_gate,
160
+ group_size=self.d_ssm // ngroups,
161
+ **factory_kwargs,
162
+ )
163
+
164
+ if self.process_group is None:
165
+ self.out_proj = nn.Linear(
166
+ self.d_inner, self.d_model, bias=bias, **factory_kwargs
167
+ )
168
+ else:
169
+ self.out_proj = RowParallelLinear(
170
+ self.d_inner * self.world_size,
171
+ self.d_model,
172
+ bias=bias,
173
+ process_group=self.process_group,
174
+ sequence_parallel=self.sequence_parallel,
175
+ **factory_kwargs,
176
+ )
177
+
+    def forward(
+        self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None
+    ):
+        """
+        u: (batch, seqlen, hidden_dim) if seqlen=None.
+            If seqlen is not None, u is (batch * seqlen, hidden_dim). This is so that when we
+            split u during sequence parallel, we split the batch * seqlen dimension
+            (in case batch is small).
+        Returns: same shape as u
+        """
+        seqlen_og = seqlen
+        if seqlen is None:
+            batch, seqlen, dim = u.shape
+        else:
+            batch_seqlen, dim = u.shape
+            batch = batch_seqlen // seqlen
+
+        conv_state, ssm_state = None, None
+        if inference_params is not None:
+            inference_batch = (
+                cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch
+            )
+            conv_state, ssm_state = self._get_states_from_cache(
+                inference_params, inference_batch
+            )
+            if inference_params.seqlen_offset > 0:
+                # The states are updated inplace
+                out, _, _ = self.step(u, conv_state, ssm_state)
+                return out
+
+        zxbcdt = self.in_proj(u)  # (B, L, d_in_proj) or (B * L, d_in_proj)
+        if seqlen_og is not None:
+            zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen)
+        # If the model is loaded in fp16, without the .float() here, A might be -inf
+        A = -torch.exp(self.A_log.float())  # (nheads) or (d_inner, d_state)
+        dt_limit_kwargs = (
+            {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
+        )
+        if self.use_mem_eff_path and inference_params is None:
+            out = mamba_split_conv1d_scan_combined(
+                zxbcdt,
+                rearrange(self.conv1d.weight, "d 1 w -> d w"),
+                self.conv1d.bias,
+                self.dt_bias,
+                A,
+                D=(
+                    rearrange(self.D, "(h p) -> h p", p=self.headdim)
+                    if self.D_has_hdim
+                    else self.D
+                ),
+                chunk_size=self.chunk_size,
+                seq_idx=seq_idx,
+                activation=self.activation,
+                rmsnorm_weight=self.norm.weight if self.rmsnorm else None,
+                rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6,
+                outproj_weight=self.out_proj.weight,
+                outproj_bias=self.out_proj.bias,
+                headdim=None if self.D_has_hdim else self.headdim,
+                ngroups=self.ngroups,
+                norm_before_gate=self.norm_before_gate,
+                **dt_limit_kwargs,
+            )
+            if seqlen_og is not None:
+                out = rearrange(out, "b l d -> (b l) d")
+            if self.process_group is not None:
+                reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
+                out = reduce_fn(out, self.process_group)
+        else:
+            d_mlp = (
+                zxbcdt.shape[-1]
+                - 2 * self.d_ssm
+                - 2 * self.ngroups * self.d_state
+                - self.nheads
+            ) // 2
+            z0, x0, z, xBC, dt = torch.split(
+                zxbcdt,
+                [
+                    d_mlp,
+                    d_mlp,
+                    self.d_ssm,
+                    self.d_ssm + 2 * self.ngroups * self.d_state,
+                    self.nheads,
+                ],
+                dim=-1,
+            )
+            if conv_state is not None:
+                if cu_seqlens is None:
+                    # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
+                    # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
+                    xBC_t = rearrange(xBC, "b l d -> b d l")
+                    conv_state.copy_(
+                        F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))
+                    )  # Update state (B D W)
+                else:
+                    assert (
+                        causal_conv1d_varlen_states is not None
+                    ), "varlen inference requires causal_conv1d package"
+                    assert (
+                        batch == 1
+                    ), "varlen inference only supports batch dimension 1"
+                    conv_varlen_states = causal_conv1d_varlen_states(
+                        xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1]
+                    )
+                    conv_state.copy_(conv_varlen_states)
+            assert self.activation in ["silu", "swish"]
+            if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
+                assert (
+                    seq_idx is None
+                ), "varlen conv1d requires the causal_conv1d package"
+                xBC = self.act(
+                    self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[
+                        :, : -(self.d_conv - 1)
+                    ]
+                )  # (B, L, self.d_ssm + 2 * ngroups * d_state)
+            else:
+                xBC = causal_conv1d_fn(
+                    xBC.transpose(1, 2),
+                    rearrange(self.conv1d.weight, "d 1 w -> d w"),
+                    bias=self.conv1d.bias,
+                    activation=self.activation,
+                    seq_idx=seq_idx,
+                ).transpose(1, 2)
+            x, B, C = torch.split(
+                xBC,
+                [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state],
+                dim=-1,
+            )
+            y = mamba_chunk_scan_combined(
+                rearrange(x, "b l (h p) -> b l h p", p=self.headdim),
+                dt,
+                A,
+                rearrange(B, "b l (g n) -> b l g n", g=self.ngroups),
+                rearrange(C, "b l (g n) -> b l g n", g=self.ngroups),
+                chunk_size=self.chunk_size,
+                D=(
+                    rearrange(self.D, "(h p) -> h p", p=self.headdim)
+                    if self.D_has_hdim
+                    else self.D
+                ),
+                z=(
+                    rearrange(z, "b l (h p) -> b l h p", p=self.headdim)
+                    if not self.rmsnorm
+                    else None
+                ),
+                dt_bias=self.dt_bias,
+                dt_softplus=True,
+                seq_idx=seq_idx,
+                cu_seqlens=cu_seqlens,
+                **dt_limit_kwargs,
+                return_final_states=ssm_state is not None,
+                return_varlen_states=cu_seqlens is not None
+                and inference_params is not None,
+            )
+            if ssm_state is not None:
+                y, last_state, *rest = y
+                if cu_seqlens is None:
+                    ssm_state.copy_(last_state)
+                else:
+                    varlen_states = rest[0]
+                    ssm_state.copy_(varlen_states)
+            y = rearrange(y, "b l h p -> b l (h p)")
+            if self.rmsnorm:
+                y = self.norm(y, z)
+            if d_mlp > 0:
+                y = torch.cat([F.silu(z0) * x0, y], dim=-1)
+            if seqlen_og is not None:
+                y = rearrange(y, "b l d -> (b l) d")
+            out = self.out_proj(y)
+        return out
+
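As a reference for what `mamba_chunk_scan_combined` computes in the non-fused branch above, here is a naive per-timestep sketch (no chunking, no varlen, no gated RMSNorm, ngroups fixed to 1). It mirrors the fallback math in `step` below and is documentation only, not an equivalent of the Triton kernel; the function name and shapes are assumptions for illustration.

import torch
import torch.nn.functional as F

def ssd_reference(x, dt, A, B, C, D, dt_bias):
    # x: (b, l, h, p), dt: (b, l, h), A: (h,), B/C: (b, l, 1, n), D: (h,)
    b, l, h, p = x.shape
    n = B.shape[-1]
    dt = F.softplus(dt.float() + dt_bias.float())                  # dt_softplus=True
    state = torch.zeros(b, h, p, n, dtype=torch.float32, device=x.device)
    ys = []
    for t in range(l):
        dA = torch.exp(dt[:, t] * A.float())                       # (b, h)
        dBx = torch.einsum("bh,bn,bhp->bhpn", dt[:, t], B[:, t, 0].float(), x[:, t].float())
        state = state * dA[..., None, None] + dBx                  # h_t = dA * h_{t-1} + dt * B * x_t
        y = torch.einsum("bhpn,bn->bhp", state, C[:, t, 0].float())
        ys.append(y + D.float()[:, None] * x[:, t].float())        # D "skip" connection
    return torch.stack(ys, dim=1).to(x.dtype)                      # (b, l, h, p)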
+    def step(self, hidden_states, conv_state, ssm_state):
+        dtype = hidden_states.dtype
+        assert (
+            hidden_states.shape[1] == 1
+        ), "Only support decoding with 1 token at a time for now"
+        zxbcdt = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
+        d_mlp = (
+            zxbcdt.shape[-1]
+            - 2 * self.d_ssm
+            - 2 * self.ngroups * self.d_state
+            - self.nheads
+        ) // 2
+        z0, x0, z, xBC, dt = torch.split(
+            zxbcdt,
+            [
+                d_mlp,
+                d_mlp,
+                self.d_ssm,
+                self.d_ssm + 2 * self.ngroups * self.d_state,
+                self.nheads,
+            ],
+            dim=-1,
+        )
+
+        # Conv step
+        if causal_conv1d_update is None:
+            conv_state.copy_(
+                torch.roll(conv_state, shifts=-1, dims=-1)
+            )  # Update state (B D W)
+            conv_state[:, :, -1] = xBC
+            xBC = torch.sum(
+                conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
+            )  # (B D)
+            if self.conv1d.bias is not None:
+                xBC = xBC + self.conv1d.bias
+            xBC = self.act(xBC).to(dtype=dtype)
+        else:
+            xBC = causal_conv1d_update(
+                xBC,
+                conv_state,
+                rearrange(self.conv1d.weight, "d 1 w -> d w"),
+                self.conv1d.bias,
+                self.activation,
+            )
+
+        x, B, C = torch.split(
+            xBC,
+            [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state],
+            dim=-1,
+        )
+        A = -torch.exp(self.A_log.float())  # (nheads,)
+
+        # SSM step
+        if selective_state_update is None:
+            assert (
+                self.ngroups == 1
+            ), "Only support ngroups=1 for this inference code path"
+            # Discretize A and B
+            dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype))  # (batch, nheads)
+            dA = torch.exp(dt * A)  # (batch, nheads)
+            x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
+            dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
+            ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
+            y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C)
+            y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
+            y = rearrange(y, "b h p -> b (h p)")
+            if not self.rmsnorm:
+                y = y * self.act(z)  # (B D)
+        else:
+            A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(
+                dtype=torch.float32
+            )
+            dt = repeat(dt, "b h -> b h p", p=self.headdim)
+            dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim)
+            D = repeat(self.D, "h -> h p", p=self.headdim)
+            B = rearrange(B, "b (g n) -> b g n", g=self.ngroups)
+            C = rearrange(C, "b (g n) -> b g n", g=self.ngroups)
+            x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim)
+            if not self.rmsnorm:
+                z = rearrange(z, "b (h p) -> b h p", p=self.headdim)
+            y = selective_state_update(
+                ssm_state,
+                x_reshaped,
+                dt,
+                A,
+                B,
+                C,
+                D,
+                z=z if not self.rmsnorm else None,
+                dt_bias=dt_bias,
+                dt_softplus=True,
+            )
+            y = rearrange(y, "b h p -> b (h p)")
+        if self.rmsnorm:
+            y = self.norm(y, z)
+        if d_mlp > 0:
+            y = torch.cat([F.silu(z0) * x0, y], dim=-1)
+        out = self.out_proj(y)
+        return out.unsqueeze(1), conv_state, ssm_state
+
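The fallback conv step above (used when `causal_conv1d_update` is unavailable) treats `conv_state` as a rolling window of the last d_conv inputs and takes a depthwise dot product with the conv weights. A small self-contained check of that equivalence against a causal depthwise nn.Conv1d, with made-up sizes:

import torch

d, w, steps = 8, 4, 6
conv = torch.nn.Conv1d(d, d, w, groups=d, padding=w - 1, bias=False)
xs = torch.randn(1, steps, d)
ref = conv(xs.transpose(1, 2)).transpose(1, 2)[:, :steps]   # causal outputs
state = torch.zeros(1, d, w)                                # (B, D, W) like conv_state
for t in range(steps):
    state = torch.roll(state, shifts=-1, dims=-1)
    state[:, :, -1] = xs[:, t]
    out = (state * conv.weight.squeeze(1)).sum(dim=-1)      # (1, d)
    torch.testing.assert_close(out, ref[:, t], rtol=1e-4, atol=1e-5)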
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+        device = self.out_proj.weight.device
+        conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
+        conv_state = torch.zeros(
+            batch_size,
+            self.d_conv,
+            self.conv1d.weight.shape[0],
+            device=device,
+            dtype=conv_dtype,
+        ).transpose(1, 2)
+        ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype
+        ssm_state = torch.zeros(
+            batch_size,
+            self.nheads,
+            self.headdim,
+            self.d_state,
+            device=device,
+            dtype=ssm_dtype,
+        )
+        return conv_state, ssm_state
+
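Shape-wise, the caches allocated above come out as (batch, d_ssm + 2*ngroups*d_state, d_conv) for the conv buffer and (batch, nheads, headdim, d_state) for the SSM state. A quick check, reusing the hypothetical `layer` from the sketch after `__init__` above (so the concrete numbers, including d_conv=4, are assumptions):

conv_state, ssm_state = layer.allocate_inference_cache(batch_size=2, max_seqlen=64)
assert conv_state.shape == (2, 512 + 2 * 64, 4)   # (B, conv_dim, d_conv), assuming d_conv=4
assert ssm_state.shape == (2, 8, 64, 64)          # (B, nheads, headdim, d_state)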
+    def _get_states_from_cache(
+        self, inference_params, batch_size, initialize_states=False
+    ):
+        assert self.layer_idx is not None
+        if self.layer_idx not in inference_params.key_value_memory_dict:
+            batch_shape = (batch_size,)
+            conv_state = torch.zeros(
+                batch_size,
+                self.d_conv,
+                self.conv1d.weight.shape[0],
+                device=self.conv1d.weight.device,
+                dtype=self.conv1d.weight.dtype,
+            ).transpose(1, 2)
+            ssm_state = torch.zeros(
+                batch_size,
+                self.nheads,
+                self.headdim,
+                self.d_state,
+                device=self.in_proj.weight.device,
+                dtype=self.in_proj.weight.dtype,
+            )
+            inference_params.key_value_memory_dict[self.layer_idx] = (
+                conv_state,
+                ssm_state,
+            )
+        else:
+            conv_state, ssm_state = inference_params.key_value_memory_dict[
+                self.layer_idx
+            ]
+            # TODO: What if batch size changes between generation, and we reuse the same states?
+            if initialize_states:
+                conv_state.zero_()
+                ssm_state.zero_()
+        return conv_state, ssm_state
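Finally, a hedged sketch of how these pieces fit together for incremental decoding. It assumes the `InferenceParams` helper from utils/generation.py (with `max_seqlen`, `max_batch_size`, `seqlen_offset`, and `key_value_memory_dict` fields) and reuses the hypothetical `layer` from above; the per-layer states are created lazily by `_get_states_from_cache` on the prefill call.

params = InferenceParams(max_seqlen=64, max_batch_size=2)
prompt = torch.randn(2, 16, 256, device="cuda", dtype=torch.bfloat16)
out = layer(prompt, inference_params=params)      # prefill: fills conv_state/ssm_state
params.seqlen_offset += prompt.shape[1]
for _ in range(8):                                # then decode one token at a time
    tok = torch.randn(2, 1, 256, device="cuda", dtype=torch.bfloat16)
    out = layer(tok, inference_params=params)     # dispatches to self.step(...)
    params.seqlen_offset += 1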