| | import math |
| | from dataclasses import dataclass |
| | from typing import Union |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | from pscan import pscan |
| |
|
| |
|
| |
|
@dataclass
class MambaConfig:
    """Hyper-parameters shared by every layer of the Mamba model.

    Two values are derived in ``__post_init__``:

    * ``d_inner`` is set to ``expand_factor * d_model``.
    * ``dt_rank`` is replaced by ``ceil(d_model / 16)`` when left at ``'auto'``.
    """

    d_model: int  # feature size of the model's input and output
    n_layers: int  # number of stacked residual blocks
    dt_rank: Union[int, str] = 'auto'  # rank of the delta projection, or 'auto'
    d_state: int = 16  # per-channel SSM state size
    expand_factor: int = 2  # d_inner = expand_factor * d_model
    d_conv: int = 4  # kernel size of the depthwise convolution

    # Initialisation of the delta (time-step) projection.
    dt_min: float = 0.001
    dt_max: float = 0.1
    dt_init: str = "random"  # "random" or "constant"
    dt_scale: float = 1.0

    bias: bool = False  # bias on the input/output linear projections
    conv_bias: bool = True  # bias on the depthwise convolution

    pscan: bool = True  # use the parallel scan instead of the sequential loop

    # Lower clamp applied to the sampled initial delta values.
    # It used to be an un-annotated class attribute, which a dataclass does not
    # treat as a field, so it could not be set through the constructor. It is
    # declared last so the positional order of the earlier fields is unchanged.
    dt_init_floor: float = 1e-4

    def __post_init__(self):
        """Fill in the values derived from the user-supplied fields."""
        self.d_inner = self.expand_factor * self.d_model

        if self.dt_rank == 'auto':
            self.dt_rank = math.ceil(self.d_model / 16)
| |
|
class Mamba(nn.Module):
    """A stack of ``config.n_layers`` residual Mamba blocks applied in order."""

    def __init__(self, config: MambaConfig):
        super().__init__()

        self.config = config

        blocks = []
        for _ in range(config.n_layers):
            blocks.append(ResidualBlock(config))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Feed ``x`` through every layer in sequence and return the result."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden

    def step(self, x, caches):
        """Advance every layer by one step.

        ``caches`` holds one entry per layer. Each entry is replaced in place
        by the cache its layer returns, and the same list is handed back
        together with the final output.
        """
        hidden = x
        for idx, block in enumerate(self.layers):
            hidden, updated = block.step(hidden, caches[idx])
            caches[idx] = updated
        return hidden, caches
| |
|
class ResidualBlock(nn.Module):
    """Pre-norm residual wrapper: ``x + mixer(norm(x))``."""

    def __init__(self, config: MambaConfig):
        super().__init__()

        self.mixer = MambaBlock(config)
        self.norm = RMSNorm(config.d_model)

    def forward(self, x):
        """Normalise ``x``, run the mixer, and add the skip connection."""
        normed = self.norm(x)
        mixed = self.mixer(normed)
        return mixed + x

    def step(self, x, cache):
        """Single-step variant of ``forward``; also returns the mixer's new cache."""
        normed = self.norm(x)
        mixed, cache = self.mixer.step(normed, cache)
        return mixed + x, cache
| |
|
class MambaBlock(nn.Module):
    """Mamba mixer: gated branch around a depthwise conv and a selective SSM.

    Shape shorthand used in the comments below:
    B = batch, L = sequence length, D = ``d_model``, ED = ``d_inner``,
    N = ``d_state``.

    ``forward`` handles a whole ``(B, L, D)`` sequence. ``step`` handles one
    position at a time and carries its state in a ``(h, inputs)`` cache.
    """

    def __init__(self, config: MambaConfig):
        super().__init__()

        self.config = config

        # One projection yields both branches: D -> 2 * ED, split in half later.
        self.in_proj = nn.Linear(config.d_model, 2 * config.d_inner, bias=config.bias)

        # Depthwise (groups == channels) conv over the sequence axis. Padding
        # of d_conv - 1 plus keeping only the first L outputs in forward()
        # means each output depends only on current and earlier positions.
        self.conv1d = nn.Conv1d(in_channels=config.d_inner, out_channels=config.d_inner,
                                kernel_size=config.d_conv, bias=config.conv_bias,
                                groups=config.d_inner,
                                padding=config.d_conv - 1)

        # Projects the conv output to the input-dependent delta, B and C.
        self.x_proj = nn.Linear(config.d_inner, config.dt_rank + 2 * config.d_state, bias=False)

        # Expands delta from dt_rank back to ED.
        self.dt_proj = nn.Linear(config.dt_rank, config.d_inner, bias=True)

        # Weight init of dt_proj, scaled by dt_rank ** -0.5.
        dt_init_std = config.dt_rank**-0.5 * config.dt_scale
        if config.dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif config.dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError(f"unsupported dt_init: {config.dt_init!r}")

        # Bias init of dt_proj: sample dt log-uniformly in [dt_min, dt_max],
        # clamp from below, then store the inverse softplus of dt so that
        # softplus(bias) == dt (ssm()/ssm_step() apply softplus to dt_proj).
        dt = torch.exp(
            torch.rand(config.d_inner) * (math.log(config.dt_max) - math.log(config.dt_min)) + math.log(config.dt_min)
        ).clamp(min=config.dt_init_floor)
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)

        # A is stored as log(A) with rows 1..N; ssm() uses -exp(A_log), so the
        # effective A is always negative.
        A = torch.arange(1, config.d_state + 1, dtype=torch.float32).repeat(config.d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(config.d_inner))

        # Projects the gated output back: ED -> D.
        self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=config.bias)

    def forward(self, x):
        """Run the block on a full sequence.

        Args:
            x: tensor of shape (B, L, D).

        Returns:
            Tensor of shape (B, L, D).
        """
        _, L, _ = x.shape

        xz = self.in_proj(x)  # (B, L, 2*ED)
        x, z = xz.chunk(2, dim=-1)  # (B, L, ED) each

        # Conv1d expects channels first; drop the trailing padded outputs.
        x = x.transpose(1, 2)  # (B, ED, L)
        x = self.conv1d(x)[:, :, :L]
        x = x.transpose(1, 2)  # (B, L, ED)

        x = F.silu(x)
        y = self.ssm(x)

        # Gate branch.
        z = F.silu(z)

        output = y * z
        output = self.out_proj(output)  # (B, L, D)

        return output

    def ssm(self, x):
        """Compute delta, B and C from ``x`` and run the selective scan.

        Args:
            x: tensor of shape (B, L, ED).

        Returns:
            Tensor of shape (B, L, ED).
        """
        A = -torch.exp(self.A_log.float())  # (ED, N)
        D = self.D.float()

        deltaBC = self.x_proj(x)  # (B, L, dt_rank + 2*N)

        delta, B, C = torch.split(deltaBC, [self.config.dt_rank, self.config.d_state, self.config.d_state], dim=-1)
        delta = F.softplus(self.dt_proj(delta))  # (B, L, ED), strictly positive

        if self.config.pscan:
            y = self.selective_scan(x, delta, A, B, C, D)
        else:
            y = self.selective_scan_seq(x, delta, A, B, C, D)

        return y

    def selective_scan(self, x, delta, A, B, C, D):
        """Selective scan using the external ``pscan`` helper.

        NOTE(review): ``pscan`` is imported from another module; it presumably
        evaluates the recurrence h_t = deltaA_t * h_{t-1} + BX_t for all t, as
        ``selective_scan_seq`` does explicitly. Confirm against that module.
        """
        deltaA = torch.exp(delta.unsqueeze(-1) * A)  # (B, L, ED, N)
        deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)  # (B, L, ED, N)

        BX = deltaB * (x.unsqueeze(-1))  # (B, L, ED, N)

        hs = pscan(deltaA, BX)

        y = (hs @ C.unsqueeze(-1)).squeeze(3)  # (B, L, ED)

        y = y + D * x

        return y

    def selective_scan_seq(self, x, delta, A, B, C, D):
        """Selective scan evaluated one time step at a time.

        Starts from a zero state and applies
        h_t = deltaA_t * h_{t-1} + BX_t, then y_t = h_t @ C_t + D * x_t.
        """
        _, L, _ = x.shape

        deltaA = torch.exp(delta.unsqueeze(-1) * A)  # (B, L, ED, N)
        deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)  # (B, L, ED, N)

        BX = deltaB * (x.unsqueeze(-1))  # (B, L, ED, N)

        h = torch.zeros(x.size(0), self.config.d_inner, self.config.d_state, device=deltaA.device)
        hs = []

        for t in range(0, L):
            h = deltaA[:, t] * h + BX[:, t]
            hs.append(h)

        hs = torch.stack(hs, dim=1)  # (B, L, ED, N)

        y = (hs @ C.unsqueeze(-1)).squeeze(3)  # (B, L, ED)

        y = y + D * x

        return y

    def step(self, x, cache):
        """Run the block on a single position.

        Args:
            x: tensor of shape (B, D).
            cache: tuple ``(h, inputs)``. ``h`` is the SSM state of shape
                (B, ED, N), or ``None`` to start from zeros. ``inputs`` is
                concatenated with the new input along dim 2 before the conv;
                it is expected to hold the previous ``d_conv - 1`` conv inputs
                as (B, ED, d_conv - 1) -- TODO confirm against the caller.

        Returns:
            ``(output, cache)`` with ``output`` of shape (B, D) and the
            updated ``(h, inputs)`` tuple.
        """
        h, inputs = cache

        xz = self.in_proj(x)  # (B, 2*ED)
        x, z = xz.chunk(2, dim=1)  # (B, ED) each

        # Convolve over [previous inputs, new input]. Output index d_conv - 1
        # is the one whose window covers exactly those d_conv unpadded columns.
        x_cache = x.unsqueeze(2)
        x = self.conv1d(torch.cat([inputs, x_cache], dim=2))[:, :, self.config.d_conv-1]  # (B, ED)

        x = F.silu(x)
        y, h = self.ssm_step(x, h)

        # Gate branch.
        z = F.silu(z)

        output = y * z
        output = self.out_proj(output)  # (B, D)

        # Slide the conv window: drop the oldest column, append the new one.
        inputs = torch.cat([inputs[:, :, 1:], x_cache], dim=2)
        cache = (h, inputs)

        return output, cache

    def ssm_step(self, x, h):
        """Advance the SSM state by one position.

        Args:
            x: tensor of shape (B, ED).
            h: previous state of shape (B, ED, N), or ``None`` for zeros.

        Returns:
            ``(y, h)`` with ``y`` of shape (B, ED) and the new state ``h`` of
            shape (B, ED, N).
        """
        A = -torch.exp(self.A_log.float())  # (ED, N)
        D = self.D.float()

        deltaBC = self.x_proj(x)  # (B, dt_rank + 2*N)

        delta, B, C = torch.split(deltaBC, [self.config.dt_rank, self.config.d_state, self.config.d_state], dim=-1)
        delta = F.softplus(self.dt_proj(delta))  # (B, ED)

        deltaA = torch.exp(delta.unsqueeze(-1) * A)  # (B, ED, N)
        deltaB = delta.unsqueeze(-1) * B.unsqueeze(1)  # (B, ED, N)

        BX = deltaB * (x.unsqueeze(-1))  # (B, ED, N)

        if h is None:
            h = torch.zeros(x.size(0), self.config.d_inner, self.config.d_state, device=deltaA.device)

        h = deltaA * h + BX  # (B, ED, N)

        y = (h @ C.unsqueeze(-1)).squeeze(2)  # (B, ED)

        y = y + D * x

        # Return the state with its full (B, ED, N) shape. A squeeze(1) here
        # would drop the ED axis whenever d_inner == 1 and break the next step.
        return y, h
| |
|
| |
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned scale."""

    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()

        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        """Divide ``x`` by the RMS of its last axis, then scale by ``weight``."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        # eps is added before rsqrt so an all-zero row does not divide by zero.
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight
| |
|